| @@ -270,6 +270,41 @@ protected: | |||||
| }; | }; | ||||
| using Remap = RemapForward; | using Remap = RemapForward; | ||||
| class RemapBackwardData : public RemapBase { | |||||
| DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1); | |||||
| public: | |||||
| virtual void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& map_xy, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) = 0; | |||||
| protected: | |||||
| void check_exec(const TensorLayout& map_xy, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_in_bytes); | |||||
| }; | |||||
| class RemapBackwardMat : public RemapBase { | |||||
| DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1); | |||||
| public: | |||||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& map_xy, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) = 0; | |||||
| protected: | |||||
| void check_exec(const TensorLayout& src, const TensorLayout& map_xy, | |||||
| const TensorLayout& diff, const TensorLayout& grad, | |||||
| size_t workspace_in_bytes); | |||||
| }; | |||||
| class SeparableFilterBase : public OperatorBase { | class SeparableFilterBase : public OperatorBase { | ||||
| DEF_OPR_IMPL_CTOR(SeparableFilterBase, OperatorBase); | DEF_OPR_IMPL_CTOR(SeparableFilterBase, OperatorBase); | ||||
| DEF_OPR_PARAM(SeparableFilter); | DEF_OPR_PARAM(SeparableFilter); | ||||
| @@ -197,6 +197,8 @@ private: | |||||
| cb(ROIAlignBackward) \ | cb(ROIAlignBackward) \ | ||||
| cb(BatchConvBiasForward) \ | cb(BatchConvBiasForward) \ | ||||
| cb(Remap) \ | cb(Remap) \ | ||||
| cb(RemapBackwardData) \ | |||||
| cb(RemapBackwardMat) \ | |||||
| /*! | /*! | ||||
| * \brief specialize HandleImpl::create_operator for a single opr type; | * \brief specialize HandleImpl::create_operator for a single opr type; | ||||
| @@ -50,6 +50,7 @@ void RemapBase::check_layout_fwd(const TensorLayout& src, | |||||
| megdnn_assert(dst.shape[0] == src.shape[0], "%s", errmsg().c_str()); | megdnn_assert(dst.shape[0] == src.shape[0], "%s", errmsg().c_str()); | ||||
| megdnn_assert(map_xy.shape[3] == 2); | megdnn_assert(map_xy.shape[3] == 2); | ||||
| megdnn_assert(map_xy.shape[0] == src.shape[0]); | megdnn_assert(map_xy.shape[0] == src.shape[0]); | ||||
| megdnn_assert_contiguous(src); | |||||
| // map_xy only supports float32 type | // map_xy only supports float32 type | ||||
| // map_xy always in NHWC format | // map_xy always in NHWC format | ||||
| @@ -85,6 +86,34 @@ void Remap::check_exec(const TensorLayout& src, const TensorLayout& map_xy, | |||||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | ||||
| } | } | ||||
| void RemapBackwardData::check_exec(const TensorLayout& map_xy, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_in_bytes) { | |||||
| check_layout_fwd(grad, map_xy, diff); | |||||
| megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16( | |||||
| || grad.dtype == dtype::BFloat16()), | |||||
| "Backward Remap only supports Float32/BFloat16."); | |||||
| auto required_workspace_in_bytes = | |||||
| get_workspace_in_bytes(map_xy, diff, grad); | |||||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||||
| } | |||||
| void RemapBackwardMat::check_exec(const TensorLayout& src, | |||||
| const TensorLayout& map_xy, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad, | |||||
| size_t workspace_in_bytes) { | |||||
| check_layout_fwd(src, map_xy, diff); | |||||
| megdnn_assert_eq_layout(map_xy, grad); | |||||
| megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16( | |||||
| || grad.dtype == dtype::BFloat16()), | |||||
| "Backward Remap only supports Float32/BFloat16."); | |||||
| auto required_workspace_in_bytes = | |||||
| get_workspace_in_bytes(src, map_xy, diff, grad); | |||||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||||
| } | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -0,0 +1,71 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/remap/backward_data.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/cuda/remap/common.h" | |||||
| #include "src/cuda/remap/opr_impl.h" | |||||
| #include "src/cuda/utils.h" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| void RemapBackwardDataImpl::exec(_megdnn_tensor_in map_xy, | |||||
| _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec(map_xy.layout, diff.layout, grad.layout, workspace.size); | |||||
| megdnn_assert(param().imode == param::Remap::InterpolationMode::LINEAR, | |||||
| "only support LINEAR interpolationMode"); | |||||
| megdnn_assert(param().format == param::Remap::Format::NCHW, | |||||
| "only support NCHW format for remap backward"); | |||||
| auto stream = cuda_stream(this->handle()); | |||||
| int N, C, IH, IW, OH, OW; | |||||
| N = grad.layout.shape[0]; | |||||
| C = grad.layout.shape[1]; | |||||
| IH = grad.layout.shape[2]; | |||||
| IW = grad.layout.shape[3]; | |||||
| OH = map_xy.layout.shape[1]; | |||||
| OW = map_xy.layout.shape[2]; | |||||
| #define cb(dt, _format, bmode) \ | |||||
| if (param().format == param::Remap::Format::_format && \ | |||||
| param().border_type == param::Remap::BorderMode::bmode) { \ | |||||
| using ctype = DTypeTrait<dt>::ctype; \ | |||||
| remap::backwarddata_proxy<ctype, param_enumv::Remap::Format::_format, \ | |||||
| ::BorderMode::BORDER_##bmode>( \ | |||||
| grad.compatible_ptr<ctype>(), \ | |||||
| map_xy.compatible_ptr<dt_float32>(), \ | |||||
| diff.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, stream); \ | |||||
| break; \ | |||||
| } | |||||
| #define support_dtype(dt) \ | |||||
| case DTypeTrait<dt>::enumv: { \ | |||||
| cb(dt, NCHW, CONSTANT); \ | |||||
| cb(dt, NCHW, REPLICATE); \ | |||||
| cb(dt, NCHW, REFLECT); \ | |||||
| cb(dt, NCHW, REFLECT_101); \ | |||||
| cb(dt, NCHW, WRAP); \ | |||||
| megdnn_throw("unsupported border type in remap cuda"); \ | |||||
| } | |||||
| switch (grad.layout.dtype.enumv()) { | |||||
| support_dtype(dtype::Float32); | |||||
| support_dtype(dtype::BFloat16); | |||||
| default: | |||||
| megdnn_throw("unsupported dtype in remap backward cuda\n"); | |||||
| } | |||||
| #undef support_dtype | |||||
| #undef cb | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,169 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/remap/backward_data.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include <cuda_runtime.h> | |||||
| #include "src/common/rounding_converter.cuh" | |||||
| #include "src/cuda/cv/kernel_common.cuh" | |||||
| #include "src/cuda/remap/common.h" | |||||
| #include "src/cuda/utils.cuh" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace remap; | |||||
| using namespace rounding; | |||||
| namespace { | |||||
| template <const uint32_t format> | |||||
| __device__ inline int get_offset(int height, int width, int channel, int h, | |||||
| int w, int c); | |||||
| template <> | |||||
| __device__ inline int get_offset<param_enumv::Remap::Format::NCHW>( | |||||
| int height, int width, int channel, int h, int w, int c) { | |||||
| return channel * h * w + height * w + width; | |||||
| } | |||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| struct GetSrcData { | |||||
| __device__ static inline int get_index(int height, int width, int channel, | |||||
| int h, int w, int c) { | |||||
| height = megcv::border_interpolate<bmode>(height, h); | |||||
| width = megcv::border_interpolate<bmode>(width, w); | |||||
| return get_offset<format>(height, width, channel, h, w, c); | |||||
| } | |||||
| }; | |||||
| template <typename ctype, const uint32_t format> | |||||
| struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> { | |||||
| __device__ static inline int get_index(int height, int width, int channel, | |||||
| int h, int w, int c) { | |||||
| return (height >= 0 && height < h && width >= 0 && width < w) | |||||
| ? get_offset<format>(height, width, channel, h, w, c) | |||||
| : -1; | |||||
| } | |||||
| }; | |||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| __global__ void kern_general(ctype* __restrict grad, const float* map_xy, | |||||
| const ctype* diff, int C, int IH, int IW, int OH, | |||||
| int OW) { | |||||
| int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||||
| int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||||
| grad += blockIdx.z * C * IH * IW; | |||||
| diff += blockIdx.z * C * OH * OW; | |||||
| map_xy += blockIdx.z * 2 * OH * OW; | |||||
| RoundingConverter<ctype> round_converter; | |||||
| if (ow < OW && oh < OH) { | |||||
| float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | |||||
| float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | |||||
| int col = static_cast<int>(floor(index_col)); | |||||
| int row = static_cast<int>(floor(index_row)); | |||||
| float v = index_col - col; // alphah | |||||
| float u = index_row - row; // alphaw | |||||
| const float one = 1.f; | |||||
| for (int c = 0; c < C; ++c) { | |||||
| float hidden = static_cast<float>( | |||||
| diff[get_offset<format>(oh, ow, c, OH, OW, C)]); | |||||
| int a00 = GetSrcData<ctype, format, bmode>::get_index( | |||||
| row + 0, col + 0, c, IH, IW, C); | |||||
| if (a00 != -1) { | |||||
| atomic_add(grad + a00, | |||||
| round_converter((one - u) * (one - v) * hidden)); | |||||
| } | |||||
| int a01 = GetSrcData<ctype, format, bmode>::get_index( | |||||
| row + 0, col + 1, c, IH, IW, C); | |||||
| if (a01 != -1) { | |||||
| atomic_add(grad + a01, round_converter((one - u) * v * hidden)); | |||||
| } | |||||
| int a10 = GetSrcData<ctype, format, bmode>::get_index( | |||||
| row + 1, col + 0, c, IH, IW, C); | |||||
| if (a10 != -1) { | |||||
| atomic_add(grad + a10, round_converter(u * (one - v) * hidden)); | |||||
| } | |||||
| int a11 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, | |||||
| bmode>::get_index(row + 1, col + 1, c, IH, IW, | |||||
| C); | |||||
| if (a11 != -1) { | |||||
| atomic_add(grad + a11, round_converter(u * v * hidden)); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| void dispatch_backwarddata(ctype* grad, const float* map_xy, const ctype* diff, | |||||
| int N, int C, int IH, int IW, int OH, int OW, | |||||
| cudaStream_t stream) { | |||||
| const int BX = 32, BY = 16; | |||||
| const int max_batch_size = 65535; | |||||
| while (N) { | |||||
| size_t curr_batch_size = N < max_batch_size ? N : max_batch_size; | |||||
| dim3 threads(BX, BY); | |||||
| dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size); | |||||
| cuda_check(cudaMemsetAsync( | |||||
| grad, 0, sizeof(ctype) * curr_batch_size * C * IH * IW, | |||||
| stream)); | |||||
| kern_general<ctype, format, bmode><<<blocks, threads, 0, stream>>>( | |||||
| grad, map_xy, diff, C, IH, IW, OH, OW); | |||||
| N -= curr_batch_size; | |||||
| grad += curr_batch_size * C * IH * IW; | |||||
| diff += curr_batch_size * C * OH * OW; | |||||
| map_xy += curr_batch_size * 2 * OH * OW; | |||||
| } | |||||
| } | |||||
| } // anonymous namespace | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| namespace remap { | |||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| void backwarddata_proxy(ctype* grad, const float* map_xy, const ctype* diff, | |||||
| int N, int C, int IH, int IW, int OH, int OW, | |||||
| cudaStream_t stream) { | |||||
| dispatch_backwarddata<ctype, format, bmode>(grad, map_xy, diff, N, C, IH, | |||||
| IW, OH, OW, stream); | |||||
| after_kernel_launch(); | |||||
| } | |||||
| #define INST(ctype, format, bmode) \ | |||||
| template void backwarddata_proxy< \ | |||||
| ctype, param_enumv::Remap::Format::format, ::BorderMode::bmode>( \ | |||||
| ctype*, const float*, const ctype*, int, int, int, int, int, int, \ | |||||
| cudaStream_t); | |||||
| #define FOR_FORMAT_BMODE(ctype) \ | |||||
| INST(ctype, NCHW, BORDER_CONSTANT) \ | |||||
| INST(ctype, NCHW, BORDER_REPLICATE) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT_101) \ | |||||
| INST(ctype, NCHW, BORDER_WRAP) | |||||
| FOR_FORMAT_BMODE(float) | |||||
| MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) | |||||
| #undef FOR_FORMAT_BMODE | |||||
| #undef INST | |||||
| } // namespace remap | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,73 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/remap/backward_mat.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include "src/cuda/remap/common.h" | |||||
| #include "src/cuda/remap/opr_impl.h" | |||||
| #include "src/cuda/utils.h" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| void RemapBackwardMatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec(src.layout, map_xy.layout, diff.layout, grad.layout, | |||||
| workspace.size); | |||||
| megdnn_assert(param().imode == param::Remap::InterpolationMode::LINEAR, | |||||
| "only support LINEAR interpolationMode"); | |||||
| megdnn_assert(param().format == param::Remap::Format::NCHW, | |||||
| "only support NCHW format for remap backward"); | |||||
| auto stream = cuda_stream(this->handle()); | |||||
| int N, C, IH, IW, OH, OW; | |||||
| N = src.layout.shape[0]; | |||||
| C = src.layout.shape[1]; | |||||
| IH = src.layout.shape[2]; | |||||
| IW = src.layout.shape[3]; | |||||
| OH = map_xy.layout.shape[1]; | |||||
| OW = map_xy.layout.shape[2]; | |||||
| #define cb(dt, _format, bmode) \ | |||||
| if (param().format == param::Remap::Format::_format && \ | |||||
| param().border_type == param::Remap::BorderMode::bmode) { \ | |||||
| using ctype = DTypeTrait<dt>::ctype; \ | |||||
| remap::backwardmat_proxy<ctype, param_enumv::Remap::Format::_format, \ | |||||
| ::BorderMode::BORDER_##bmode>( \ | |||||
| src.compatible_ptr<ctype>(), \ | |||||
| map_xy.compatible_ptr<dt_float32>(), \ | |||||
| diff.compatible_ptr<ctype>(), \ | |||||
| grad.compatible_ptr<dt_float32>(), N, C, IH, IW, OH, OW, \ | |||||
| param().scalar, stream); \ | |||||
| break; \ | |||||
| } | |||||
| #define support_dtype(dt) \ | |||||
| case DTypeTrait<dt>::enumv: { \ | |||||
| cb(dt, NCHW, CONSTANT); \ | |||||
| cb(dt, NCHW, REPLICATE); \ | |||||
| cb(dt, NCHW, REFLECT); \ | |||||
| cb(dt, NCHW, REFLECT_101); \ | |||||
| cb(dt, NCHW, WRAP); \ | |||||
| megdnn_throw("unsupported border type in remap cuda"); \ | |||||
| } | |||||
| switch (src.layout.dtype.enumv()) { | |||||
| support_dtype(dtype::Float32); | |||||
| support_dtype(dtype::BFloat16); | |||||
| default: | |||||
| megdnn_throw("unsupported dtype in remap backward cuda\n"); | |||||
| } | |||||
| #undef support_dtype | |||||
| #undef cb | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,170 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/remap/backward_mat.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
| * implied. | |||||
| */ | |||||
| #include <cuda_runtime.h> | |||||
| #include "src/common/rounding_converter.cuh" | |||||
| #include "src/cuda/cv/kernel_common.cuh" | |||||
| #include "src/cuda/remap/common.h" | |||||
| #include "src/cuda/utils.cuh" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| using namespace remap; | |||||
| using namespace rounding; | |||||
| namespace { | |||||
| template <const uint32_t format> | |||||
| __device__ inline int get_offset(int height, int width, int channel, int h, | |||||
| int w, int c); | |||||
| template <> | |||||
| __device__ inline int get_offset<param_enumv::Remap::Format::NCHW>( | |||||
| int height, int width, int channel, int h, int w, int c) { | |||||
| return channel * h * w + height * w + width; | |||||
| } | |||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| struct GetSrcData { | |||||
| __device__ static inline int get_index(int height, int width, int channel, | |||||
| int h, int w, int c) { | |||||
| height = megcv::border_interpolate<bmode>(height, h); | |||||
| width = megcv::border_interpolate<bmode>(width, w); | |||||
| return get_offset<format>(height, width, channel, h, w, c); | |||||
| } | |||||
| }; | |||||
| template <typename ctype, const uint32_t format> | |||||
| struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> { | |||||
| __device__ static inline int get_index(int height, int width, int channel, | |||||
| int h, int w, int c) { | |||||
| return (height >= 0 && height < h && width >= 0 && width < w) | |||||
| ? get_offset<format>(height, width, channel, h, w, c) | |||||
| : -1; | |||||
| } | |||||
| }; | |||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| __global__ void kern_general(const ctype* src, const float* map_xy, | |||||
| const ctype* diff, float* __restrict grad, int C, | |||||
| int IH, int IW, int OH, int OW, float scalar) { | |||||
| int ow = blockIdx.x * blockDim.x + threadIdx.x; | |||||
| int oh = blockIdx.y * blockDim.y + threadIdx.y; | |||||
| src += blockIdx.z * C * IH * IW; | |||||
| diff += blockIdx.z * C * OH * OW; | |||||
| map_xy += blockIdx.z * 2 * OH * OW; | |||||
| grad += blockIdx.z * 2 * OH * OW; | |||||
| RoundingConverter<ctype> round_converter; | |||||
| if (ow < OW && oh < OH) { | |||||
| float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | |||||
| float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | |||||
| int col = static_cast<int>(floor(index_col)); | |||||
| int row = static_cast<int>(floor(index_row)); | |||||
| float v = index_col - col; // alphaw | |||||
| float u = index_row - row; // alphah | |||||
| const float one = 1.f; | |||||
| for (int c = 0; c < C; ++c) { | |||||
| float hidden = static_cast<float>( | |||||
| diff[get_offset<format>( | |||||
| oh, ow, c, OH, OW, C)]); | |||||
| float du = 0.f, dv = 0.f; | |||||
| int a00 = GetSrcData<ctype, format, bmode>::get_index( | |||||
| row + 0, col + 0, c, IH, IW, C); | |||||
| int a01 = GetSrcData<ctype, format, bmode>::get_index( | |||||
| row + 0, col + 1, c, IH, IW, C); | |||||
| int a10 = GetSrcData<ctype, format, bmode>::get_index( | |||||
| row + 1, col + 0, c, IH, IW, C); | |||||
| int a11 = GetSrcData<ctype, format, bmode>::get_index( | |||||
| row + 1, col + 1, c, IH, IW, C); | |||||
| dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u); | |||||
| dv += ((a01 != -1) ? src[a01] : scalar) * (one - u); | |||||
| dv -= ((a10 != -1) ? src[a10] : scalar) * u; | |||||
| dv += ((a11 != -1) ? src[a11] : scalar) * u; | |||||
| du -= ((a00 != -1) ? src[a00] : scalar) * (one - v); | |||||
| du -= ((a01 != -1) ? src[a01] : scalar) * v; | |||||
| du += ((a10 != -1) ? src[a10] : scalar) * (one - v); | |||||
| du += ((a11 != -1) ? src[a11] : scalar) * v; | |||||
| grad[oh * OW * 2 + ow * 2 + 0] += round_converter(hidden * dv); | |||||
| grad[oh * OW * 2 + ow * 2 + 1] += round_converter(hidden * du); | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| void dispatch_backwardmat(const ctype* src, const float* map_xy, | |||||
| const ctype* diff, float* grad, int N, int C, int IH, | |||||
| int IW, int OH, int OW, float scalar, | |||||
| cudaStream_t stream) { | |||||
| const int BX = 32, BY = 16; | |||||
| const int max_batch_size = 65535; | |||||
| while (N) { | |||||
| size_t curr_batch_size = N < max_batch_size ? N : max_batch_size; | |||||
| dim3 threads(BX, BY); | |||||
| dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size); | |||||
| cuda_check(cudaMemsetAsync( | |||||
| grad, 0, sizeof(float) * curr_batch_size * OH * OW * 2, | |||||
| stream)); | |||||
| kern_general<ctype, format, bmode><<<blocks, threads, 0, stream>>>( | |||||
| src, map_xy, diff, grad, C, IH, IW, OH, OW, scalar); | |||||
| N -= curr_batch_size; | |||||
| src += curr_batch_size * C * IH * IW; | |||||
| diff += curr_batch_size * C * OH * OW; | |||||
| map_xy += curr_batch_size * 2 * OH * OW; | |||||
| grad += curr_batch_size * 2 * OH * OW; | |||||
| } | |||||
| } | |||||
| } // anonymous namespace | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| namespace remap { | |||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| void backwardmat_proxy(const ctype* src, const float* map_xy, const ctype* diff, | |||||
| float* grad, int N, int C, int IH, int IW, int OH, | |||||
| int OW, float scalar, cudaStream_t stream) { | |||||
| dispatch_backwardmat<ctype, format, bmode>(src, map_xy, diff, grad, N, C, | |||||
| IH, IW, OH, OW, scalar, stream); | |||||
| after_kernel_launch(); | |||||
| } | |||||
| #define INST(ctype, format, bmode) \ | |||||
| template void backwardmat_proxy<ctype, param_enumv::Remap::Format::format, \ | |||||
| ::BorderMode::bmode>( \ | |||||
| const ctype*, const float*, const ctype*, float*, int, int, int, \ | |||||
| int, int, int, float, cudaStream_t); | |||||
| #define FOR_FORMAT_BMODE(ctype) \ | |||||
| INST(ctype, NCHW, BORDER_CONSTANT) \ | |||||
| INST(ctype, NCHW, BORDER_REPLICATE) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT) \ | |||||
| INST(ctype, NCHW, BORDER_REFLECT_101) \ | |||||
| INST(ctype, NCHW, BORDER_WRAP) | |||||
| FOR_FORMAT_BMODE(float) | |||||
| MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) | |||||
| #undef FOR_FORMAT_BMODE | |||||
| #undef INST | |||||
| } // namespace remap | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -24,7 +24,17 @@ namespace remap { | |||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | template <typename ctype, const uint32_t format, ::BorderMode bmode> | ||||
| void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N, | void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N, | ||||
| int C, int IH, int IW, int OH, int OW, float scalar, | int C, int IH, int IW, int OH, int OW, float scalar, | ||||
| int S_IN, int S_IC, int S_IH, int S_IW, cudaStream_t stream); | |||||
| cudaStream_t stream); | |||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| void backwarddata_proxy(ctype* grad, const float* map_xy, const ctype* diff, | |||||
| int N, int C, int IH, int IW, int OH, int OW, | |||||
| cudaStream_t stream); | |||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| void backwardmat_proxy(const ctype* src, const float* map_xy, const ctype* diff, | |||||
| float* grad, int N, int C, int IH, int IW, int OH, | |||||
| int OW, float scalar, cudaStream_t stream); | |||||
| } // namespace remap | } // namespace remap | ||||
| } // namespace cuda | } // namespace cuda | ||||
| @@ -22,9 +22,10 @@ using namespace cuda; | |||||
| void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy, | void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy, | ||||
| _megdnn_tensor_in dst, _megdnn_workspace workspace) { | _megdnn_tensor_in dst, _megdnn_workspace workspace) { | ||||
| check_exec(src.layout, map_xy.layout, dst.layout, workspace.size); | check_exec(src.layout, map_xy.layout, dst.layout, workspace.size); | ||||
| megdnn_assert(map_xy.layout.dtype.enumv() == | |||||
| DTypeTrait<dtype::Float32>::enumv); | |||||
| auto stream = cuda_stream(this->handle()); | auto stream = cuda_stream(this->handle()); | ||||
| int N, C, IH, IW, OH, OW; | int N, C, IH, IW, OH, OW; | ||||
| ptrdiff_t S_IN = 0, S_IC = 0, S_IH = 0, S_IW = 0; | |||||
| OH = map_xy.layout.shape[1]; | OH = map_xy.layout.shape[1]; | ||||
| OW = map_xy.layout.shape[2]; | OW = map_xy.layout.shape[2]; | ||||
| @@ -36,10 +37,6 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy, | |||||
| C = src.layout.shape[1]; | C = src.layout.shape[1]; | ||||
| IH = src.layout.shape[2]; | IH = src.layout.shape[2]; | ||||
| IW = src.layout.shape[3]; | IW = src.layout.shape[3]; | ||||
| S_IN = src.layout.stride[0]; | |||||
| S_IC = src.layout.stride[1]; | |||||
| S_IH = src.layout.stride[2]; | |||||
| S_IW = src.layout.stride[3]; | |||||
| } else if (param().format == param::Remap::Format::NHWC) { | } else if (param().format == param::Remap::Format::NHWC) { | ||||
| N = src.layout.shape[0]; | N = src.layout.shape[0]; | ||||
| C = src.layout.shape[3]; | C = src.layout.shape[3]; | ||||
| @@ -58,7 +55,7 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy, | |||||
| src.compatible_ptr<ctype>(), \ | src.compatible_ptr<ctype>(), \ | ||||
| map_xy.compatible_ptr<dt_float32>(), \ | map_xy.compatible_ptr<dt_float32>(), \ | ||||
| dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, \ | dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, \ | ||||
| param().scalar, S_IN, S_IC, S_IH, S_IW, stream); \ | |||||
| param().scalar, stream); \ | |||||
| break; \ | break; \ | ||||
| } | } | ||||
| @@ -78,15 +75,16 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy, | |||||
| } | } | ||||
| switch (src.layout.dtype.enumv()) { | switch (src.layout.dtype.enumv()) { | ||||
| support_dtype(dtype::Float32) | |||||
| MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16)) | |||||
| support_dtype(dtype::Int8) | |||||
| support_dtype(dtype::Uint8) | |||||
| support_dtype(dtype::Float32); | |||||
| MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16)); | |||||
| MEGDNN_INC_FLOAT16(support_dtype(dtype::BFloat16)); | |||||
| support_dtype(dtype::Int8); | |||||
| support_dtype(dtype::Uint8); | |||||
| default: | default: | ||||
| megdnn_throw("unsupported dtype in remap cuda"); | megdnn_throw("unsupported dtype in remap cuda"); | ||||
| } | } | ||||
| #undef supported_dtype | |||||
| #undef support_dtype | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| @@ -23,17 +23,6 @@ using namespace rounding; | |||||
| namespace { | namespace { | ||||
| template <typename ctype> | |||||
| struct DirectSrcVisitor { | |||||
| const ctype* ptr; | |||||
| __device__ __forceinline__ const ctype* get(int batch, int im_size) { | |||||
| return ptr + batch * im_size; | |||||
| } | |||||
| void move_batch(size_t batch, size_t im_size) { ptr += batch * im_size; } | |||||
| }; | |||||
| template <const uint32_t format> | template <const uint32_t format> | ||||
| __device__ inline int get_offset(int height, int width, int channel, int h, | __device__ inline int get_offset(int height, int width, int channel, int h, | ||||
| int w, int c); | int w, int c); | ||||
| @@ -74,14 +63,13 @@ struct GetSrcData<ctype, format, ::BorderMode::BORDER_CONSTANT> { | |||||
| } | } | ||||
| }; | }; | ||||
| template <typename ctype, typename SrcVisitor, ::BorderMode bmode> | |||||
| __global__ void kern_general(SrcVisitor src, const float* map_xy, | |||||
| template <typename ctype, ::BorderMode bmode> | |||||
| __global__ void kern_general(const ctype* __restrict sptr, const float* map_xy, | |||||
| ctype* __restrict dst, int C, int IH, int IW, | ctype* __restrict dst, int C, int IH, int IW, | ||||
| int OH, int OW, int S_IN, int S_IC, int S_IH, | |||||
| int S_IW, float scalar) { | |||||
| int OH, int OW, float scalar) { | |||||
| int ow = blockIdx.x * blockDim.x + threadIdx.x; | int ow = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| int oh = blockIdx.y * blockDim.y + threadIdx.y; | int oh = blockIdx.y * blockDim.y + threadIdx.y; | ||||
| const ctype* __restrict sptr = src.get(blockIdx.z, S_IN); | |||||
| sptr += blockIdx.z * C * IH * IW; | |||||
| dst += blockIdx.z * C * OH * OW; | dst += blockIdx.z * C * OH * OW; | ||||
| map_xy += blockIdx.z * 2 * OH * OW; | map_xy += blockIdx.z * 2 * OH * OW; | ||||
| RoundingConverter<ctype> round_converter; | RoundingConverter<ctype> round_converter; | ||||
| @@ -89,8 +77,8 @@ __global__ void kern_general(SrcVisitor src, const float* map_xy, | |||||
| if (ow < OW && oh < OH) { | if (ow < OW && oh < OH) { | ||||
| float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | ||||
| float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | ||||
| int col = (int)floor(index_col); | |||||
| int row = (int)floor(index_row); | |||||
| int col = static_cast<int>(floor(index_col)); | |||||
| int row = static_cast<int>(floor(index_row)); | |||||
| float v = index_col - col; | float v = index_col - col; | ||||
| float u = index_row - row; | float u = index_row - row; | ||||
| for (int c = 0; c < C; ++c) { | for (int c = 0; c < C; ++c) { | ||||
| @@ -106,22 +94,25 @@ __global__ void kern_general(SrcVisitor src, const float* map_xy, | |||||
| ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, | ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NCHW, | ||||
| bmode>::get(sptr, row + 1, col + 1, c, IH, | bmode>::get(sptr, row + 1, col + 1, c, IH, | ||||
| IW, C, scalar); | IW, C, scalar); | ||||
| dst[get_offset<param_enumv::Remap::Format::NCHW>(oh, ow, c, OH, OW, | |||||
| C)] = | |||||
| round_converter(a00 * (1.f - u) * (1.f - v) + | |||||
| a01 * (1.f - u) * v + a10 * (1.f - v) * u + | |||||
| a11 * u * v); | |||||
| /* in remap, we use float as the type of intermediate result */ | |||||
| float result = static_cast<float>(a00) * (1.f - u) * (1.f - v) + | |||||
| static_cast<float>(a01) * (1.f - u) * v + | |||||
| static_cast<float>(a10) * (1.f - v) * u + | |||||
| static_cast<float>(a11) * u * v; | |||||
| dst[get_offset<param_enumv::Remap::Format::NCHW>( | |||||
| oh, ow, c, OH, OW, C)] = round_converter(result); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| template <typename ctype, typename SrcVisitor, ::BorderMode bmode> | |||||
| __global__ void kern_general_nhwc(SrcVisitor src, const float* map_xy, | |||||
| ctype* __restrict dst, int C, int IH, int IW, | |||||
| int OH, int OW, float scalar) { | |||||
| template <typename ctype, ::BorderMode bmode> | |||||
| __global__ void kern_general_nhwc(const ctype* __restrict sptr, | |||||
| const float* map_xy, ctype* __restrict dst, | |||||
| int C, int IH, int IW, int OH, int OW, | |||||
| float scalar) { | |||||
| int ow = blockIdx.x * blockDim.x + threadIdx.x; | int ow = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| int oh = blockIdx.y * blockDim.y + threadIdx.y; | int oh = blockIdx.y * blockDim.y + threadIdx.y; | ||||
| const ctype* __restrict sptr = src.get(blockIdx.z, C * IH * IW); | |||||
| sptr += blockIdx.z * C * IH * IW; | |||||
| dst += blockIdx.z * C * OH * OW; | dst += blockIdx.z * C * OH * OW; | ||||
| map_xy += blockIdx.z * 2 * OH * OW; | map_xy += blockIdx.z * 2 * OH * OW; | ||||
| RoundingConverter<ctype> round_converter; | RoundingConverter<ctype> round_converter; | ||||
| @@ -129,8 +120,8 @@ __global__ void kern_general_nhwc(SrcVisitor src, const float* map_xy, | |||||
| if (ow < OW && oh < OH) { | if (ow < OW && oh < OH) { | ||||
| float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | float index_col = map_xy[oh * OW * 2 + ow * 2 + 0]; | ||||
| float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | float index_row = map_xy[oh * OW * 2 + ow * 2 + 1]; | ||||
| int col = (int)floor(index_col); | |||||
| int row = (int)floor(index_row); | |||||
| int col = static_cast<int>(floor(index_col)); | |||||
| int row = static_cast<int>(floor(index_row)); | |||||
| float v = index_col - col; | float v = index_col - col; | ||||
| float u = index_row - row; | float u = index_row - row; | ||||
| for (int c = 0; c < C; ++c) { | for (int c = 0; c < C; ++c) { | ||||
| @@ -146,21 +137,21 @@ __global__ void kern_general_nhwc(SrcVisitor src, const float* map_xy, | |||||
| ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, | ctype a11 = GetSrcData<ctype, param_enumv::Remap::Format::NHWC, | ||||
| bmode>::get(sptr, row + 1, col + 1, c, IH, | bmode>::get(sptr, row + 1, col + 1, c, IH, | ||||
| IW, C, scalar); | IW, C, scalar); | ||||
| dst[get_offset<param_enumv::Remap::Format::NHWC>(oh, ow, c, OH, OW, | |||||
| C)] = | |||||
| round_converter(a00 * (1.f - u) * (1.f - v) + | |||||
| a01 * (1.f - u) * v + a10 * (1.f - v) * u + | |||||
| a11 * u * v); | |||||
| /* in remap, we use float as the type of intermediate result */ | |||||
| float result = static_cast<float>(a00) * (1.f - u) * (1.f - v) + | |||||
| static_cast<float>(a01) * (1.f - u) * v + | |||||
| static_cast<float>(a10) * (1.f - v) * u + | |||||
| static_cast<float>(a11) * u * v; | |||||
| dst[get_offset<param_enumv::Remap::Format::NHWC>( | |||||
| oh, ow, c, OH, OW, C)] = round_converter(result); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| template <typename ctype, typename SrcVisitor, const uint32_t format, | |||||
| ::BorderMode bmode> | |||||
| void dispatch_with_visitor(SrcVisitor src, const float* map_xy, ctype* dst, | |||||
| int N, int C, int IH, int IW, int OH, int OW, | |||||
| float scalar, int S_IN, int S_IC, int S_IH, int S_IW, | |||||
| cudaStream_t stream) { | |||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | |||||
| void dispatch_forward(const ctype* src, const float* map_xy, ctype* dst, int N, | |||||
| int C, int IH, int IW, int OH, int OW, float scalar, | |||||
| cudaStream_t stream) { | |||||
| const int BX = 32, BY = 16; | const int BX = 32, BY = 16; | ||||
| const int max_batch_size = 65535; | const int max_batch_size = 65535; | ||||
| @@ -170,19 +161,17 @@ void dispatch_with_visitor(SrcVisitor src, const float* map_xy, ctype* dst, | |||||
| dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size); | dim3 blocks((OW + BX - 1) / BX, (OH + BY - 1) / BY, curr_batch_size); | ||||
| if (format == param_enumv::Remap::Format::NCHW) { | if (format == param_enumv::Remap::Format::NCHW) { | ||||
| kern_general<ctype, SrcVisitor, bmode> | |||||
| <<<blocks, threads, 0, stream>>>(src, map_xy, dst, C, IH, | |||||
| IW, OH, OW, S_IN, S_IC, | |||||
| S_IH, S_IW, scalar); | |||||
| kern_general<ctype, bmode><<<blocks, threads, 0, stream>>>( | |||||
| src, map_xy, dst, C, IH, IW, OH, OW, scalar); | |||||
| } else if (format == param_enumv::Remap::Format::NHWC) { | } else if (format == param_enumv::Remap::Format::NHWC) { | ||||
| kern_general_nhwc<ctype, SrcVisitor, bmode> | |||||
| <<<blocks, threads, 0, stream>>>(src, map_xy, dst, C, IH, | |||||
| IW, OH, OW, scalar); | |||||
| kern_general_nhwc<ctype, bmode><<<blocks, threads, 0, stream>>>( | |||||
| src, map_xy, dst, C, IH, IW, OH, OW, scalar); | |||||
| } | } | ||||
| N -= curr_batch_size; | N -= curr_batch_size; | ||||
| src.move_batch(curr_batch_size, C * IH * IW); | |||||
| src += curr_batch_size * C * IH * IW; | |||||
| dst += curr_batch_size * C * OH * OW; | dst += curr_batch_size * C * OH * OW; | ||||
| map_xy += curr_batch_size * OH * OW * 2; | |||||
| } | } | ||||
| } | } | ||||
| @@ -195,22 +184,17 @@ namespace remap { | |||||
| template <typename ctype, const uint32_t format, ::BorderMode bmode> | template <typename ctype, const uint32_t format, ::BorderMode bmode> | ||||
| void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N, | void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N, | ||||
| int C, int IH, int IW, int OH, int OW, float scalar, | int C, int IH, int IW, int OH, int OW, float scalar, | ||||
| int S_IN, int S_IC, int S_IH, int S_IW, | |||||
| cudaStream_t stream) { | cudaStream_t stream) { | ||||
| DirectSrcVisitor<ctype> visitor; | |||||
| visitor.ptr = src; | |||||
| using SrcVisitor = DirectSrcVisitor<ctype>; | |||||
| dispatch_with_visitor<ctype, SrcVisitor, format, bmode>( | |||||
| visitor, map_xy, dst, N, C, IH, IW, OH, OW, scalar, S_IN, S_IC, | |||||
| S_IH, S_IW, stream); | |||||
| dispatch_forward<ctype, format, bmode>(src, map_xy, dst, N, C, IH, IW, OH, | |||||
| OW, scalar, stream); | |||||
| after_kernel_launch(); | after_kernel_launch(); | ||||
| } | } | ||||
| #define INST(ctype, format, bmode) \ | |||||
| template void forward_proxy<ctype, param_enumv::Remap::Format::format, \ | |||||
| ::BorderMode::bmode>( \ | |||||
| const ctype* src, const float*, ctype*, int, int, int, int, int, \ | |||||
| int, float, int, int, int, int, cudaStream_t); | |||||
| #define INST(ctype, format, bmode) \ | |||||
| template void forward_proxy<ctype, param_enumv::Remap::Format::format, \ | |||||
| ::BorderMode::bmode>( \ | |||||
| const ctype*, const float*, ctype*, int, int, int, int, int, int, \ | |||||
| float, cudaStream_t); | |||||
| #define FOR_FORMAT_BMODE(ctype) \ | #define FOR_FORMAT_BMODE(ctype) \ | ||||
| INST(ctype, NCHW, BORDER_CONSTANT) \ | INST(ctype, NCHW, BORDER_CONSTANT) \ | ||||
| @@ -226,11 +210,13 @@ void forward_proxy(const ctype* src, const float* map_xy, ctype* dst, int N, | |||||
| FOR_FORMAT_BMODE(float) | FOR_FORMAT_BMODE(float) | ||||
| MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16)) | MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_float16)) | ||||
| MEGDNN_INC_FLOAT16(FOR_FORMAT_BMODE(dt_bfloat16)) | |||||
| FOR_FORMAT_BMODE(int8_t) | FOR_FORMAT_BMODE(int8_t) | ||||
| FOR_FORMAT_BMODE(uint8_t) | FOR_FORMAT_BMODE(uint8_t) | ||||
| #undef FOR_BMODE | |||||
| #undef FOR_FORMAT_BMODE | |||||
| #undef INST | #undef INST | ||||
| } // namespace remap | } // namespace remap | ||||
| } // namespace cuda | } // namespace cuda | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -15,13 +15,41 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| class RemapImpl final : public Remap { | class RemapImpl final : public Remap { | ||||
| public: | |||||
| using Remap::Remap; | using Remap::Remap; | ||||
| void exec(_megdnn_tensor_in, _megdnn_tensor_in, _megdnn_tensor_out, | |||||
| _megdnn_workspace) override; | |||||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
| const TensorLayout&) override { | |||||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& map_xy, | |||||
| const TensorLayout& dst) override { | |||||
| return 0; | |||||
| } | |||||
| }; | |||||
| class RemapBackwardDataImpl final : public RemapBackwardData { | |||||
| public: | |||||
| using RemapBackwardData::RemapBackwardData; | |||||
| void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout& map_xy, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) override { | |||||
| return 0; | |||||
| } | |||||
| }; | |||||
| class RemapBackwardMatImpl final : public RemapBackwardMat { | |||||
| public: | |||||
| using RemapBackwardMat::RemapBackwardMat; | |||||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& map_xy, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) override { | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| }; | }; | ||||
| @@ -12,11 +12,13 @@ | |||||
| #include "src/naive/remap/opr_impl.h" | #include "src/naive/remap/opr_impl.h" | ||||
| #include "src/common/cv/helper.h" | #include "src/common/cv/helper.h" | ||||
| #include "src/common/rounding_converter.cuh" | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/naive/handle.h" | #include "src/naive/handle.h" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace naive; | using namespace naive; | ||||
| using namespace rounding; | |||||
| namespace { | namespace { | ||||
| template <param::Remap::Format format> | template <param::Remap::Format format> | ||||
| @@ -36,35 +38,46 @@ inline int get_offset<param::Remap::Format::NHWC>(int height, int width, | |||||
| return height * w * c + width * c + channel; | return height * w * c + width * c + channel; | ||||
| } | } | ||||
| template <typename DataType, param::Remap::Format format, | |||||
| template <typename ctype, param::Remap::Format format, | |||||
| param::Remap::BorderMode bordertype> | param::Remap::BorderMode bordertype> | ||||
| struct GetSrcData { | struct GetSrcData { | ||||
| static inline DataType get(const DataType* src, int height, int width, | |||||
| int channel, int h, int w, int c, float, | |||||
| std::function<DataType(float)>) { | |||||
| static inline ctype get(const ctype* src, int height, int width, | |||||
| int channel, int h, int w, int c, float) { | |||||
| height = megcv::border_interpolate<bordertype>(height, h); | height = megcv::border_interpolate<bordertype>(height, h); | ||||
| width = megcv::border_interpolate<bordertype>(width, w); | width = megcv::border_interpolate<bordertype>(width, w); | ||||
| return src[get_offset<format>(height, width, channel, h, w, c)]; | return src[get_offset<format>(height, width, channel, h, w, c)]; | ||||
| } | } | ||||
| static inline int get_index(int height, int width, int channel, int h, | |||||
| int w, int c) { | |||||
| height = megcv::border_interpolate<bordertype>(height, h); | |||||
| width = megcv::border_interpolate<bordertype>(width, w); | |||||
| return get_offset<format>(height, width, channel, h, w, c); | |||||
| } | |||||
| }; | }; | ||||
| template <typename DataType, param::Remap::Format format> | |||||
| struct GetSrcData<DataType, format, param::Remap::BorderMode::CONSTANT> { | |||||
| static inline DataType get(const DataType* src, int height, int width, | |||||
| int channel, int h, int w, int c, float scalar, | |||||
| std::function<DataType(float)> round) { | |||||
| template <typename ctype, param::Remap::Format format> | |||||
| struct GetSrcData<ctype, format, param::Remap::BorderMode::CONSTANT> { | |||||
| static inline ctype get(const ctype* src, int height, int width, | |||||
| int channel, int h, int w, int c, float scalar) { | |||||
| RoundingConverter<ctype> round; | |||||
| return (height >= 0 && height < h && width >= 0 && width < w) | return (height >= 0 && height < h && width >= 0 && width < w) | ||||
| ? src[get_offset<format>(height, width, channel, h, w, | ? src[get_offset<format>(height, width, channel, h, w, | ||||
| c)] | c)] | ||||
| : static_cast<DataType>(round(scalar)); | |||||
| : round(scalar); | |||||
| } | |||||
| static inline int get_index(int height, int width, int channel, int h, | |||||
| int w, int c) { | |||||
| return (height >= 0 && height < h && width >= 0 && width < w) | |||||
| ? get_offset<format>(height, width, channel, h, w, c) | |||||
| : -1; | |||||
| } | } | ||||
| }; | }; | ||||
| template <typename DataType, param::Remap::Format format, | |||||
| template <typename ctype, param::Remap::Format format, | |||||
| param::Remap::BorderMode bordertype> | param::Remap::BorderMode bordertype> | ||||
| void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst, | |||||
| int N, int C, int IH, int IW, int OH, int OW, float scalar, | |||||
| std::function<DataType(float)> round) { | |||||
| void remap_LINEAR(const ctype* src, const float* map_xy, ctype* dst, int N, | |||||
| int C, int IH, int IW, int OH, int OW, float scalar) { | |||||
| RoundingConverter<ctype> round_converter; | |||||
| for (int n = 0; n < N; | for (int n = 0; n < N; | ||||
| ++n, src += C * IH * IW, dst += C * OH * OW, map_xy += OH * OW * 2) { | ++n, src += C * IH * IW, dst += C * OH * OW, map_xy += OH * OW * 2) { | ||||
| for (int h = 0; h < OH; ++h) { | for (int h = 0; h < OH; ++h) { | ||||
| @@ -73,47 +86,131 @@ void remap_LINEAR(const DataType* src, const float* map_xy, DataType* dst, | |||||
| float index_row = map_xy[h * OW * 2 + w * 2 + 1]; | float index_row = map_xy[h * OW * 2 + w * 2 + 1]; | ||||
| int col = static_cast<int>(floor(index_col)); | int col = static_cast<int>(floor(index_col)); | ||||
| int row = static_cast<int>(floor(index_row)); | int row = static_cast<int>(floor(index_row)); | ||||
| float v = index_col - col; | |||||
| float u = index_row - row; | |||||
| float one = 1.f; | |||||
| float v = index_col - col; // alphaw | |||||
| float u = index_row - row; // alphah | |||||
| const float one = 1.f; | |||||
| for (int c = 0; c < C; ++c) { | for (int c = 0; c < C; ++c) { | ||||
| DataType a00 = | |||||
| GetSrcData<DataType, format, bordertype>::get( | |||||
| src, row + 0, col + 0, c, IH, IW, C, scalar, | |||||
| round); | |||||
| DataType a01 = | |||||
| GetSrcData<DataType, format, bordertype>::get( | |||||
| src, row + 0, col + 1, c, IH, IW, C, scalar, | |||||
| round); | |||||
| DataType a10 = | |||||
| GetSrcData<DataType, format, bordertype>::get( | |||||
| src, row + 1, col + 0, c, IH, IW, C, scalar, | |||||
| round); | |||||
| DataType a11 = | |||||
| GetSrcData<DataType, format, bordertype>::get( | |||||
| src, row + 1, col + 1, c, IH, IW, C, scalar, | |||||
| round); | |||||
| ctype a00 = GetSrcData<ctype, format, bordertype>::get( | |||||
| src, row + 0, col + 0, c, IH, IW, C, scalar); | |||||
| ctype a01 = GetSrcData<ctype, format, bordertype>::get( | |||||
| src, row + 0, col + 1, c, IH, IW, C, scalar); | |||||
| ctype a10 = GetSrcData<ctype, format, bordertype>::get( | |||||
| src, row + 1, col + 0, c, IH, IW, C, scalar); | |||||
| ctype a11 = GetSrcData<ctype, format, bordertype>::get( | |||||
| src, row + 1, col + 1, c, IH, IW, C, scalar); | |||||
| dst[get_offset<format>(h, w, c, OH, OW, C)] = | dst[get_offset<format>(h, w, c, OH, OW, C)] = | ||||
| static_cast<DataType>( | |||||
| round(a00 * (one - u) * (one - v) + | |||||
| a01 * (one - u) * v + | |||||
| a10 * (one - v) * u + a11 * u * v)); | |||||
| round_converter(a00 * (one - v) * (one - u) + | |||||
| a01 * (one - u) * v + | |||||
| a10 * (one - v) * u + a11 * u * v); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| template <typename DataType, DTypeCategory cat> | |||||
| struct Round { | |||||
| static inline DataType round(float x) { return std::round(x); } | |||||
| }; | |||||
| template <typename ctype, param::Remap::Format format, | |||||
| param::Remap::BorderMode bordertype> | |||||
| void remap_LINEAR_backwarddata(ctype* grad, const float* map_xy, | |||||
| const ctype* diff, int N, int C, int IH, int IW, | |||||
| int OH, int OW) { | |||||
| RoundingConverter<ctype> round_converter; | |||||
| std::memset(grad, 0, sizeof(ctype) * N * C * IH * IW); | |||||
| for (int n = 0; n < N; | |||||
| ++n, grad += C * IH * IW, diff += C * OH * OW, map_xy += OH * OW * 2) { | |||||
| for (int h = 0; h < OH; ++h) { | |||||
| for (int w = 0; w < OW; ++w) { | |||||
| float index_col = map_xy[h * OW * 2 + w * 2 + 0]; | |||||
| float index_row = map_xy[h * OW * 2 + w * 2 + 1]; | |||||
| int col = static_cast<int>(floor(index_col)); | |||||
| int row = static_cast<int>(floor(index_row)); | |||||
| float v = index_col - col; // alphaw | |||||
| float u = index_row - row; // alphah | |||||
| const float one = 1.f; | |||||
| for (int c = 0; c < C; ++c) { | |||||
| ctype hidden = diff[get_offset<format>(h, w, c, OH, OW, C)]; | |||||
| template <typename DataType> | |||||
| struct Round<DataType, DTypeCategory::FLOAT> { | |||||
| static inline DataType round(float x) { return static_cast<DataType>(x); } | |||||
| }; | |||||
| int a00 = GetSrcData<ctype, format, bordertype>::get_index( | |||||
| row + 0, col + 0, c, IH, IW, C); | |||||
| if (a00 != -1) { | |||||
| grad[a00] += | |||||
| round_converter((one - v) * (one - u) * hidden); | |||||
| } | |||||
| int a01 = GetSrcData<ctype, format, bordertype>::get_index( | |||||
| row + 0, col + 1, c, IH, IW, C); | |||||
| if (a01 != -1) { | |||||
| grad[a01] += round_converter((one - u) * v * hidden); | |||||
| } | |||||
| int a10 = GetSrcData<ctype, format, bordertype>::get_index( | |||||
| row + 1, col + 0, c, IH, IW, C); | |||||
| if (a10 != -1) { | |||||
| grad[a10] += round_converter(u * (one - v) * hidden); | |||||
| } | |||||
| int a11 = GetSrcData<ctype, format, bordertype>::get_index( | |||||
| row + 1, col + 1, c, IH, IW, C); | |||||
| if (a11 != -1) { | |||||
| grad[a11] += round_converter(v * u * hidden); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| template <typename ctype, param::Remap::Format format, | |||||
| param::Remap::BorderMode bordertype> | |||||
| void remap_LINEAR_backwardmat(const ctype* src, const float* map_xy, | |||||
| const ctype* diff, float* grad, int N, int C, | |||||
| int IH, int IW, int OH, int OW, float scalar) { | |||||
| RoundingConverter<ctype> round_converter; | |||||
| std::memset(grad, 0, sizeof(float) * N * 2 * OH * OW); | |||||
| for (int n = 0; n < N; ++n, src += C * IH * IW, diff += C * OH * OW, | |||||
| map_xy += OH * OW * 2, grad += OH * OW * 2) { | |||||
| for (int h = 0; h < OH; ++h) { | |||||
| for (int w = 0; w < OW; ++w) { | |||||
| float index_col = map_xy[h * OW * 2 + w * 2 + 0]; | |||||
| float index_row = map_xy[h * OW * 2 + w * 2 + 1]; | |||||
| int col = static_cast<int>(floor(index_col)); | |||||
| int row = static_cast<int>(floor(index_row)); | |||||
| float v = index_col - col; // alphaw | |||||
| float u = index_row - row; // alphah | |||||
| const float one = 1.f; | |||||
| for (int c = 0; c < C; ++c) { | |||||
| float hidden = static_cast<float>( | |||||
| diff[get_offset<format>(h, w, c, OH, OW, C)]); | |||||
| float du = 0.f, dv = 0.f; | |||||
| int a00 = GetSrcData<ctype, format, bordertype>::get_index( | |||||
| row + 0, col + 0, c, IH, IW, C); | |||||
| int a01 = GetSrcData<ctype, format, bordertype>::get_index( | |||||
| row + 0, col + 1, c, IH, IW, C); | |||||
| int a10 = GetSrcData<ctype, format, bordertype>::get_index( | |||||
| row + 1, col + 0, c, IH, IW, C); | |||||
| int a11 = GetSrcData<ctype, format, bordertype>::get_index( | |||||
| row + 1, col + 1, c, IH, IW, C); | |||||
| dv -= ((a00 != -1) ? src[a00] : scalar) * (one - u); | |||||
| dv += ((a01 != -1) ? src[a01] : scalar) * (one - u); | |||||
| dv -= ((a10 != -1) ? src[a10] : scalar) * u; | |||||
| dv += ((a11 != -1) ? src[a11] : scalar) * u; | |||||
| du -= ((a00 != -1) ? src[a00] : scalar) * (one - v); | |||||
| du -= ((a01 != -1) ? src[a01] : scalar) * v; | |||||
| du += ((a10 != -1) ? src[a10] : scalar) * (one - v); | |||||
| du += ((a11 != -1) ? src[a11] : scalar) * v; | |||||
| grad[h * OW * 2 + w * 2 + 0] += | |||||
| round_converter(hidden * dv); | |||||
| grad[h * OW * 2 + w * 2 + 1] += | |||||
| round_converter(hidden * du); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| @@ -148,8 +245,7 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||||
| src.compatible_ptr<ctype>(), \ | src.compatible_ptr<ctype>(), \ | ||||
| map_xy.compatible_ptr<dt_float32>(), \ | map_xy.compatible_ptr<dt_float32>(), \ | ||||
| dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, \ | dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, \ | ||||
| param().scalar, \ | |||||
| Round<ctype, DTypeTrait<dt>::category>::round))); \ | |||||
| param().scalar))); \ | |||||
| break; \ | break; \ | ||||
| } | } | ||||
| @@ -172,6 +268,7 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||||
| support_dtype(dtype::Float32); | support_dtype(dtype::Float32); | ||||
| MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16)); | MEGDNN_INC_FLOAT16(support_dtype(dtype::Float16)); | ||||
| MEGDNN_INC_FLOAT16(support_dtype(dtype::BFloat16)); | |||||
| support_dtype(dtype::Int8); | support_dtype(dtype::Int8); | ||||
| support_dtype(dtype::Uint8); | support_dtype(dtype::Uint8); | ||||
| #undef cb | #undef cb | ||||
| @@ -181,3 +278,109 @@ void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||||
| megdnn_throw("unsupported dtype in remap naive\n"); | megdnn_throw("unsupported dtype in remap naive\n"); | ||||
| } | } | ||||
| } | } | ||||
| void RemapBackwardDataImpl::exec(_megdnn_tensor_in map_xy, | |||||
| _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec(map_xy.layout, diff.layout, grad.layout, workspace.size); | |||||
| megdnn_assert(param().format == param::Remap::Format::NCHW, | |||||
| "only support NCHW format for remap backward"); | |||||
| int N, C, IH, IW, OH, OW; | |||||
| N = grad.layout.shape[0]; | |||||
| C = grad.layout.shape[1]; | |||||
| IH = grad.layout.shape[2]; | |||||
| IW = grad.layout.shape[3]; | |||||
| OH = map_xy.layout.shape[1]; | |||||
| OW = map_xy.layout.shape[2]; | |||||
| switch (diff.layout.dtype.enumv()) { | |||||
| #define cb(dt, fmt, border, interpolation) \ | |||||
| if (param().format == param::Remap::Format::fmt && \ | |||||
| param().border_type == param::Remap::BorderMode::border && \ | |||||
| param().imode == param::Remap::InterpolationMode::interpolation) { \ | |||||
| using ctype = DTypeTrait<dt>::ctype; \ | |||||
| MEGDNN_DISPATCH_CPU_KERN_OPR((remap_##interpolation##_backwarddata< \ | |||||
| ctype, param::Remap::Format::fmt, \ | |||||
| param::Remap::BorderMode::border>( \ | |||||
| grad.compatible_ptr<ctype>(), \ | |||||
| map_xy.compatible_ptr<dt_float32>(), \ | |||||
| diff.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW))); \ | |||||
| break; \ | |||||
| } | |||||
| #define support_dtype(dt) \ | |||||
| case DTypeTrait<dt>::enumv: { \ | |||||
| cb(dt, NCHW, CONSTANT, LINEAR); \ | |||||
| cb(dt, NCHW, REPLICATE, LINEAR); \ | |||||
| cb(dt, NCHW, REFLECT, LINEAR); \ | |||||
| cb(dt, NCHW, REFLECT_101, LINEAR); \ | |||||
| cb(dt, NCHW, WRAP, LINEAR); \ | |||||
| megdnn_throw( \ | |||||
| "format, border type or imode is incorrect in remap navie " \ | |||||
| "with dtype = " #dt); \ | |||||
| } | |||||
| support_dtype(dtype::Float32); | |||||
| MEGDNN_INC_FLOAT16(support_dtype(dtype::BFloat16)); | |||||
| #undef cb | |||||
| #undef support_dtype | |||||
| default: | |||||
| megdnn_throw("unsupported dtype in remap backward naive\n"); | |||||
| } | |||||
| } | |||||
| void RemapBackwardMatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) { | |||||
| check_exec(src.layout, map_xy.layout, diff.layout, grad.layout, | |||||
| workspace.size); | |||||
| megdnn_assert(param().format == param::Remap::Format::NCHW, | |||||
| "only support NCHW format for remap backward"); | |||||
| int N, C, IH, IW, OH, OW; | |||||
| N = src.layout.shape[0]; | |||||
| C = src.layout.shape[1]; | |||||
| IH = src.layout.shape[2]; | |||||
| IW = src.layout.shape[3]; | |||||
| OH = map_xy.layout.shape[1]; | |||||
| OW = map_xy.layout.shape[2]; | |||||
| switch (src.layout.dtype.enumv()) { | |||||
| #define cb(dt, fmt, border, interpolation) \ | |||||
| if (param().format == param::Remap::Format::fmt && \ | |||||
| param().border_type == param::Remap::BorderMode::border && \ | |||||
| param().imode == param::Remap::InterpolationMode::interpolation) { \ | |||||
| using ctype = DTypeTrait<dt>::ctype; \ | |||||
| MEGDNN_DISPATCH_CPU_KERN_OPR((remap_##interpolation##_backwardmat< \ | |||||
| ctype, param::Remap::Format::fmt, \ | |||||
| param::Remap::BorderMode::border>( \ | |||||
| src.compatible_ptr<ctype>(), \ | |||||
| map_xy.compatible_ptr<dt_float32>(), \ | |||||
| diff.compatible_ptr<ctype>(), \ | |||||
| grad.compatible_ptr<dt_float32>(), N, C, IH, IW, OH, OW, \ | |||||
| param().scalar))); \ | |||||
| break; \ | |||||
| } | |||||
| #define support_dtype(dt) \ | |||||
| case DTypeTrait<dt>::enumv: { \ | |||||
| cb(dt, NCHW, CONSTANT, LINEAR); \ | |||||
| cb(dt, NCHW, REPLICATE, LINEAR); \ | |||||
| cb(dt, NCHW, REFLECT, LINEAR); \ | |||||
| cb(dt, NCHW, REFLECT_101, LINEAR); \ | |||||
| cb(dt, NCHW, WRAP, LINEAR); \ | |||||
| megdnn_throw( \ | |||||
| "format, border type or imode is incorrect in remap navie " \ | |||||
| "with dtype = " #dt); \ | |||||
| } | |||||
| support_dtype(dtype::Float32); | |||||
| MEGDNN_INC_FLOAT16(support_dtype(dtype::BFloat16)); | |||||
| #undef cb | |||||
| #undef support_dtype | |||||
| default: | |||||
| megdnn_throw("unsupported dtype in remap backward naive\n"); | |||||
| } | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -23,6 +23,33 @@ class RemapImpl final : public Remap { | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| }; | }; | ||||
| class RemapBackwardDataImpl final : public RemapBackwardData { | |||||
| public: | |||||
| using RemapBackwardData::RemapBackwardData; | |||||
| void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&) override { | |||||
| return 0; | |||||
| } | |||||
| }; | |||||
| class RemapBackwardMatImpl final : public RemapBackwardMat { | |||||
| public: | |||||
| using RemapBackwardMat::RemapBackwardMat; | |||||
| void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) override; | |||||
| size_t get_workspace_in_bytes(const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&, | |||||
| const TensorLayout&) override { | |||||
| return 0; | |||||
| } | |||||
| }; | |||||
| } // namespace naive | } // namespace naive | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -106,6 +106,8 @@ DEF(DeformablePSROIPoolingForward, 5, true, true); | |||||
| DEF(DeformablePSROIPoolingBackward, 7, true, false); | DEF(DeformablePSROIPoolingBackward, 7, true, false); | ||||
| DEF(BatchConvBiasForward, 5, true, true); | DEF(BatchConvBiasForward, 5, true, true); | ||||
| DEF(Remap, 3, true, true); | DEF(Remap, 3, true, true); | ||||
| DEF(RemapBackwardData, 3, true, false); | |||||
| DEF(RemapBackwardMat, 4, true, false); | |||||
| } // namespace test | } // namespace test | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -46,6 +46,9 @@ static inline std::vector<TestArg> get_nchw_args() { | |||||
| for (auto border_type : border_mode_vec) { | for (auto border_type : border_mode_vec) { | ||||
| param.format = fmt; | param.format = fmt; | ||||
| param.border_type = border_type; | param.border_type = border_type; | ||||
| args.emplace_back(param, TensorShape{70000, 1, 2, 2}, | |||||
| TensorShape{70000, 2, 2, 2}, TensorShape{70000, 1, 2, 2}); | |||||
| args.emplace_back(param, TensorShape{1, 1, 2, 2}, | args.emplace_back(param, TensorShape{1, 1, 2, 2}, | ||||
| TensorShape{1, 2, 2, 2}, TensorShape{1, 1, 2, 2}); | TensorShape{1, 2, 2, 2}, TensorShape{1, 1, 2, 2}); | ||||
| @@ -90,6 +93,9 @@ static inline std::vector<TestArg> get_nhwc_args() { | |||||
| param.format = fmt; | param.format = fmt; | ||||
| param.border_type = border_type; | param.border_type = border_type; | ||||
| param.scalar = 12.f; | param.scalar = 12.f; | ||||
| args.emplace_back(param, TensorShape{70000, 2, 2, 1}, | |||||
| TensorShape{70000, 2, 2, 2}, TensorShape{70000, 2, 2, 1}); | |||||
| args.emplace_back(param, TensorShape{1, 2, 2, 1}, | args.emplace_back(param, TensorShape{1, 2, 2, 1}, | ||||
| TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 2, 1}); | TensorShape{1, 2, 2, 2}, TensorShape{1, 2, 2, 1}); | ||||
| @@ -40,6 +40,22 @@ TEST_F(CUDA, REMAP_NCHW_FLOAT) { | |||||
| cb(dtype::Float32(), float_rng); | cb(dtype::Float32(), float_rng); | ||||
| cb(dtype::Float16(), float_rng); | cb(dtype::Float16(), float_rng); | ||||
| #undef cb | #undef cb | ||||
| #define cb(data_type, data_rng) \ | |||||
| for (auto arg : args) { \ | |||||
| UniformFloatRNG map_rng( \ | |||||
| -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \ | |||||
| checker.set_dtype(0, data_type) \ | |||||
| .set_dtype(1, dtype::Float32()) \ | |||||
| .set_dtype(2, data_type) \ | |||||
| .set_rng(0, &data_rng) \ | |||||
| .set_rng(1, &map_rng) \ | |||||
| .set_rng(2, &data_rng) \ | |||||
| .set_param(arg.param) \ | |||||
| .set_epsilon(1e-2) \ | |||||
| .execs({arg.src, arg.map_xy, arg.dst}); \ | |||||
| } | |||||
| cb(dtype::BFloat16(), float_rng); | |||||
| #undef cb | |||||
| } | } | ||||
| TEST_F(CUDA, REMAP_NCHW_INT) { | TEST_F(CUDA, REMAP_NCHW_INT) { | ||||
| @@ -87,6 +103,22 @@ TEST_F(CUDA, REMAP_NHWC_FLOAT) { | |||||
| cb(dtype::Float32(), float_rng); | cb(dtype::Float32(), float_rng); | ||||
| cb(dtype::Float16(), float_rng); | cb(dtype::Float16(), float_rng); | ||||
| #undef cb | #undef cb | ||||
| #define cb(data_type, data_rng) \ | |||||
| for (auto arg : args) { \ | |||||
| UniformFloatRNG map_rng( \ | |||||
| -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \ | |||||
| checker.set_dtype(0, data_type) \ | |||||
| .set_dtype(1, dtype::Float32()) \ | |||||
| .set_dtype(2, data_type) \ | |||||
| .set_rng(0, &data_rng) \ | |||||
| .set_rng(1, &map_rng) \ | |||||
| .set_rng(2, &data_rng) \ | |||||
| .set_param(arg.param) \ | |||||
| .set_epsilon(1e-2) \ | |||||
| .execs({arg.src, arg.map_xy, arg.dst}); \ | |||||
| } | |||||
| cb(dtype::BFloat16(), float_rng); | |||||
| #undef cb | |||||
| } | } | ||||
| TEST_F(CUDA, REMAP_NHWC_INT) { | TEST_F(CUDA, REMAP_NHWC_INT) { | ||||
| @@ -114,6 +146,85 @@ TEST_F(CUDA, REMAP_NHWC_INT) { | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| // Runs the CUDA RemapBackwardData checker over the NCHW test arguments, | |||||
| // first with Float32 data and then with BFloat16 data. | |||||
| TEST_F(CUDA, REMAP_BACKWARD_DATA) { | |||||
| Checker<RemapBackwardData> checker(handle_cuda()); | |||||
| std::vector<TestArg> args = get_nchw_args(); | |||||
| UniformFloatRNG float_rng(0, 255); | |||||
| // Float32 pass. Tensors are handed to execs() as {map_xy, dst, src}: | |||||
| // index 0 is the Float32 coordinate map, indices 1 and 2 use data_type | |||||
| // and data_rng. map_rng is constructed with the bounds -2 and | |||||
| // max(map_xy.shape[1], map_xy.shape[2]) + 2 (presumably so that some | |||||
| // coordinates land outside the image -- TODO confirm). | |||||
| // This pass does not call set_epsilon(). | |||||
| #define cb(data_type, data_rng) \ | |||||
| for (auto arg : args) { \ | |||||
| UniformFloatRNG map_rng( \ | |||||
| -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \ | |||||
| checker.set_dtype(1, data_type) \ | |||||
| .set_dtype(0, dtype::Float32()) \ | |||||
| .set_dtype(2, data_type) \ | |||||
| .set_rng(1, &data_rng) \ | |||||
| .set_rng(0, &map_rng) \ | |||||
| .set_rng(2, &data_rng) \ | |||||
| .set_param(arg.param) \ | |||||
| .execs({arg.map_xy, arg.dst, arg.src}); \ | |||||
| } | |||||
| cb(dtype::Float32(), float_rng); | |||||
| #undef cb | |||||
| // BFloat16 pass: same tensor/dtype/rng setup as above, but with | |||||
| // set_epsilon(1e-1). | |||||
| #define cb(data_type, data_rng) \ | |||||
| for (auto arg : args) { \ | |||||
| UniformFloatRNG map_rng( \ | |||||
| -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \ | |||||
| checker.set_dtype(1, data_type) \ | |||||
| .set_dtype(0, dtype::Float32()) \ | |||||
| .set_dtype(2, data_type) \ | |||||
| .set_rng(1, &data_rng) \ | |||||
| .set_rng(0, &map_rng) \ | |||||
| .set_rng(2, &data_rng) \ | |||||
| .set_param(arg.param) \ | |||||
| .set_epsilon(1e-1) \ | |||||
| .execs({arg.map_xy, arg.dst, arg.src}); \ | |||||
| } | |||||
| cb(dtype::BFloat16(), float_rng); | |||||
| #undef cb | |||||
| } | |||||
| // Runs the CUDA RemapBackwardMat checker over the NCHW test arguments, | |||||
| // first with Float32 data and then with BFloat16 data. | |||||
| TEST_F(CUDA, REMAP_BACKWARD_MAT) { | |||||
| Checker<RemapBackwardMat> checker(handle_cuda()); | |||||
| std::vector<TestArg> args = get_nchw_args(); | |||||
| UniformFloatRNG float_rng(0, 255); | |||||
| // Float32 pass. Tensors are handed to execs() as | |||||
| // {src, map_xy, dst, map_xy}: indices 0 and 2 use data_type and data_rng, | |||||
| // indices 1 and 3 are Float32 and use map_rng; the fourth tensor has the | |||||
| // same shape as map_xy. map_rng is constructed with the bounds -2 and | |||||
| // max(map_xy.shape[1], map_xy.shape[2]) + 2. Epsilon is set to 2e-2. | |||||
| #define cb(data_type, data_rng) \ | |||||
| for (auto arg : args) { \ | |||||
| UniformFloatRNG map_rng( \ | |||||
| -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \ | |||||
| checker.set_dtype(0, data_type) \ | |||||
| .set_dtype(1, dtype::Float32()) \ | |||||
| .set_dtype(2, data_type) \ | |||||
| .set_dtype(3, dtype::Float32()) \ | |||||
| .set_rng(0, &data_rng) \ | |||||
| .set_rng(1, &map_rng) \ | |||||
| .set_rng(2, &data_rng) \ | |||||
| .set_rng(3, &map_rng) \ | |||||
| .set_param(arg.param) \ | |||||
| .set_epsilon(2e-2) \ | |||||
| .execs({arg.src, arg.map_xy, arg.dst, arg.map_xy}); \ | |||||
| } | |||||
| cb(dtype::Float32(), float_rng); | |||||
| #undef cb | |||||
| // BFloat16 pass: same tensor/dtype/rng setup as above, but with | |||||
| // set_epsilon(1e-1). | |||||
| #define cb(data_type, data_rng) \ | |||||
| for (auto arg : args) { \ | |||||
| UniformFloatRNG map_rng( \ | |||||
| -2, std::max(arg.map_xy.shape[2], arg.map_xy.shape[1]) + 2); \ | |||||
| checker.set_dtype(0, data_type) \ | |||||
| .set_dtype(1, dtype::Float32()) \ | |||||
| .set_dtype(2, data_type) \ | |||||
| .set_dtype(3, dtype::Float32()) \ | |||||
| .set_rng(0, &data_rng) \ | |||||
| .set_rng(1, &map_rng) \ | |||||
| .set_rng(2, &data_rng) \ | |||||
| .set_rng(3, &map_rng) \ | |||||
| .set_param(arg.param) \ | |||||
| .set_epsilon(1e-1) \ | |||||
| .execs({arg.src, arg.map_xy, arg.dst, arg.map_xy}); \ | |||||
| } | |||||
| cb(dtype::BFloat16(), float_rng); | |||||
| #undef cb | |||||
| } | |||||
| #if MEGDNN_WITH_BENCHMARK | #if MEGDNN_WITH_BENCHMARK | ||||
| TEST_F(CUDA, BENCHMARK_REMAP) { | TEST_F(CUDA, BENCHMARK_REMAP) { | ||||
| @@ -144,13 +255,31 @@ TEST_F(CUDA, BENCHMARK_REMAP) { | |||||
| .execs(shapes); | .execs(shapes); | ||||
| auto t2 = benchmarker_cuda.set_display(false).set_param(param).execs( | auto t2 = benchmarker_cuda.set_display(false).set_param(param).execs( | ||||
| shapes); | shapes); | ||||
| int size = 0; | |||||
| if (dtype == dtype::Float32{}) { | |||||
| size = sizeof(float); | |||||
| printf("float32: "); | |||||
| } else if (dtype == dtype::Float16{}) { | |||||
| size = sizeof(dt_float16); | |||||
| printf("float16: "); | |||||
| } else if (dtype == dtype::Int8{}) { | |||||
| size = sizeof(dt_int8); | |||||
| printf("int8: "); | |||||
| } else if (dtype == dtype::Uint8{}) { | |||||
| size = sizeof(dt_uint8); | |||||
| printf("uint8: "); | |||||
| } | |||||
| const TensorShape map_xy = shapes[1]; | |||||
| const TensorShape dst_layout = shapes[2]; | const TensorShape dst_layout = shapes[2]; | ||||
| float calc_amount = dst_layout.total_nr_elems(); | |||||
| printf("naive={%.3fms, %.3fMflops}, " | |||||
| "cuda={%.3fms, %.3fMflops}\n", | |||||
| t1 / RUN, calc_amount / (t1 / RUN * 1000), t2, | |||||
| calc_amount / (t2 * 1000)); | |||||
| float calc_amount = (dst_layout.total_nr_elems() * (4.f + 1.f) * size + | |||||
| map_xy.total_nr_elems() * sizeof(float)) / | |||||
| (1024 * 1024 * 1024); | |||||
| printf("naive={%.3fms, %.3fGBPS}, " | |||||
| "cuda={%.3fms, %.3fGBPS}\n", | |||||
| t1 / RUN, calc_amount / (t1 / RUN) * 1e3, t2, | |||||
| calc_amount / t2 * 1e3); | |||||
| }; | }; | ||||
| Param param; | Param param; | ||||
| param.imode = param::Remap::InterpolationMode::LINEAR; | param.imode = param::Remap::InterpolationMode::LINEAR; | ||||
| @@ -84,6 +84,7 @@ from .nn import ( | |||||
| max_pool2d, | max_pool2d, | ||||
| one_hot, | one_hot, | ||||
| prelu, | prelu, | ||||
| remap, | |||||
| roi_align, | roi_align, | ||||
| roi_pooling, | roi_pooling, | ||||
| softmax, | softmax, | ||||
| @@ -705,6 +705,61 @@ def warp_perspective( | |||||
| ) | ) | ||||
| @wrap_io_tensor | |||||
| def remap( | |||||
| inp: Tensor, | |||||
| map_xy: Tensor, | |||||
| border_mode: str = "REPLICATE", | |||||
| scalar: float = 0.0, | |||||
| interp_mode: str = "LINEAR", | |||||
| ) -> Tensor: | |||||
| r""" | |||||
| Applies a remap transformation to batched 2D images. | |||||
| The input images are transformed into the output images according to the tensor ``map_xy``. | |||||
| The output's H and W are the same as map_xy's H and W. | |||||
| :param inp: input image; it is passed to the underlying operator with ``format="NCHW"`` | |||||
| :param map_xy: (batch, oh, ow, 2) coordinate map; in the example below the last axis is read as (x, y), i.e. (column, row) of the input | |||||
| :param border_mode: pixel extrapolation method. Default: ``"REPLICATE"`` | |||||
| :param scalar: value used in case of a constant border. Default: ``0`` | |||||
| :param interp_mode: interpolation method. Default: ``"LINEAR"`` | |||||
| :return: the remapped tensor | |||||
| Examples: | |||||
| .. testcode:: | |||||
| import numpy as np | |||||
| from megengine import tensor | |||||
| import megengine.functional as F | |||||
| inp_shape = (1, 1, 4, 4) | |||||
| inp = tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) | |||||
| map_xy_shape = (1, 2, 2, 2) | |||||
| map_xy = tensor(np.array([[[1., 0.],[0., 1.]], | |||||
| [[0., 1.],[0., 1.]]], | |||||
| dtype=np.float32).reshape(map_xy_shape)) | |||||
| out = F.remap(inp, map_xy) | |||||
| print(out.numpy()) | |||||
| Outputs: | |||||
| .. testoutput:: | |||||
| [[[[1. 4.] | |||||
| [4. 4.]]]] | |||||
| """ | |||||
| return mgb.opr.remap( | |||||
| inp, | |||||
| map_xy, | |||||
| border_type=border_mode, | |||||
| scalar=scalar, | |||||
| imode=interp_mode, | |||||
| format="NCHW", | |||||
| ) | |||||
| @wrap_io_tensor | @wrap_io_tensor | ||||
| def eye( | def eye( | ||||
| n: int, | n: int, | ||||
| @@ -443,4 +443,29 @@ void RemapForward::init_output_dtype() { | |||||
| output(0)->dtype(input(0)->dtype()); | output(0)->dtype(input(0)->dtype()); | ||||
| } | } | ||||
| #ifdef MGB_ENABLE_GRAD | |||||
| //! gradient of RemapForward: input 0 is handled by RemapBackwardData, | |||||
| //! input 1 by RemapBackwardMat; any other index yields InvalidGrad | |||||
| MGB_IMPL_OPR_GRAD(RemapForward) { | |||||
| mgb_assert(opr.input().size() == 2); | |||||
| if (wrt_idx == 0) { | |||||
| //! inputs are passed as (map, out_diff, in_for_shape) | |||||
| return RemapBackwardData::make(opr.input(1), out_grad[0], | |||||
| opr.input(0), opr.param()) | |||||
| .node(); | |||||
| } | |||||
| if (wrt_idx == 1) { | |||||
| //! inputs are passed as (src, map, out_diff) | |||||
| return RemapBackwardMat::make(opr.input(0), opr.input(1), | |||||
| out_grad[0], opr.param()) | |||||
| .node(); | |||||
| } | |||||
| return InvalidGrad::make(opr, wrt_idx); | |||||
| } | |||||
| #endif | |||||
| /* ====================== RemapBackward ====================== */ | |||||
| //! RemapBackwardData takes (map, out_diff, in_for_shape); the trailing | |||||
| //! "2, false" presumably selects input 2 (in_for_shape) as the input whose | |||||
| //! shape the output follows -- TODO confirm against MEGDNN_OPR_INIT3 | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemapBackwardData); | |||||
| MEGDNN_OPR_INIT3(RemapBackwardData, "remap_bwd_data", 2, false); | |||||
| //! RemapBackwardMat takes (src, map, out_diff); the trailing "1, true" | |||||
| //! presumably selects input 1 (map) as the input whose shape the output | |||||
| //! follows -- TODO confirm against MEGDNN_OPR_INIT3 | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(RemapBackwardMat); | |||||
| MEGDNN_OPR_INIT3(RemapBackwardMat, "remap_bwd_mat", 1, true); | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -97,6 +97,8 @@ namespace opr { | |||||
| MGB_SEREG_OPR(ResizeBackward, 2); | MGB_SEREG_OPR(ResizeBackward, 2); | ||||
| MGB_SEREG_OPR(Remap, 2); | MGB_SEREG_OPR(Remap, 2); | ||||
| MGB_SEREG_OPR(RemapBackwardData, 3); | |||||
| MGB_SEREG_OPR(RemapBackwardMat, 3); | |||||
| //! current warp affine version | //! current warp affine version | ||||
| using WarpAffineV1 = opr::WarpAffine; | using WarpAffineV1 = opr::WarpAffine; | ||||
| @@ -74,7 +74,7 @@ size_t get_workspace_size_bytes( | |||||
| const TensorShapeArray& output_shapes) const override; | const TensorShapeArray& output_shapes) const override; | ||||
| void record_execute_deps(ExecDependencyArray& deps) override; | void record_execute_deps(ExecDependencyArray& deps) override; | ||||
| }; // namespace opr | |||||
| }; | |||||
| using WarpPerspective = WarpPerspectiveForward; | using WarpPerspective = WarpPerspectiveForward; | ||||
| MGB_DEFINE_OPR_CLASS( | MGB_DEFINE_OPR_CLASS( | ||||
| @@ -98,7 +98,7 @@ static SymbolVar make(SymbolVar mat, SymbolVar mat_idx, SymbolVar out_diff, | |||||
| const OperatorNodeConfig& config = {}); | const OperatorNodeConfig& config = {}); | ||||
| void scn_do_execute() override; | void scn_do_execute() override; | ||||
| }; // namespace mgb | |||||
| }; | |||||
| MGB_DEFINE_OPR_CLASS( | MGB_DEFINE_OPR_CLASS( | ||||
| WarpPerspectiveBackwardMat, | WarpPerspectiveBackwardMat, | ||||
| @@ -119,8 +119,7 @@ static SymbolVar make(SymbolVar src, SymbolVar mat, SymbolVar mat_idx, | |||||
| const OperatorNodeConfig& config = {}); | const OperatorNodeConfig& config = {}); | ||||
| void scn_do_execute() override; | void scn_do_execute() override; | ||||
| } | |||||
| ; | |||||
| }; | |||||
| /* ============================= shape infer ============================== */ | /* ============================= shape infer ============================== */ | ||||
| //! param: src, dst | //! param: src, dst | ||||
| @@ -164,8 +163,7 @@ size_t get_workspace_size_bytes( | |||||
| const TensorShapeArray& input_shapes, | const TensorShapeArray& input_shapes, | ||||
| const TensorShapeArray& output_shapes) const override; | const TensorShapeArray& output_shapes) const override; | ||||
| void record_execute_deps(ExecDependencyArray& deps) override; | void record_execute_deps(ExecDependencyArray& deps) override; | ||||
| } | |||||
| ; | |||||
| }; | |||||
| using Resize = ResizeForward; | using Resize = ResizeForward; | ||||
| MGB_DEFINE_OPR_CLASS(ResizeBackward, | MGB_DEFINE_OPR_CLASS(ResizeBackward, | ||||
| @@ -177,8 +175,7 @@ ResizeBackward(VarNode* out_diff, VarNode* in_for_shape, const Param& param, | |||||
| static SymbolVar make(SymbolVar out_diff, SymbolVar in_for_shape, | static SymbolVar make(SymbolVar out_diff, SymbolVar in_for_shape, | ||||
| const Param& param = {}, | const Param& param = {}, | ||||
| const OperatorNodeConfig& config = {}); | const OperatorNodeConfig& config = {}); | ||||
| } | |||||
| ; | |||||
| }; | |||||
| MGB_DEFINE_OPR_CLASS(RemapForward, | MGB_DEFINE_OPR_CLASS(RemapForward, | ||||
| intl::MegDNNOprWrapperFwd<megdnn::RemapForward>) // { | intl::MegDNNOprWrapperFwd<megdnn::RemapForward>) // { | ||||
| @@ -192,10 +189,31 @@ static SymbolVar make(SymbolVar in_tensor, SymbolVar map, | |||||
| private: | private: | ||||
| void init_output_dtype() override; | void init_output_dtype() override; | ||||
| } | |||||
| ; | |||||
| }; | |||||
| using Remap = RemapForward; | using Remap = RemapForward; | ||||
| /*! | |||||
| * \brief graph operator wrapping megdnn::RemapBackwardData | |||||
| * | |||||
| * Used by the RemapForward gradient for its input 0. | |||||
| * | |||||
| * \param map the coordinate map that was given to the forward operator | |||||
| * \param out_diff gradient flowing in from the forward output | |||||
| * \param in_for_shape the forward input image (name suggests it is only | |||||
| * used for its shape -- TODO confirm) | |||||
| */ | |||||
| MGB_DEFINE_OPR_CLASS(RemapBackwardData, | |||||
| intl::MegDNNOprWrapperBwd<megdnn::RemapBackwardData>) // { | |||||
| public: | |||||
| RemapBackwardData(VarNode *map, VarNode *out_diff, | |||||
| VarNode *in_for_shape, const Param &param, | |||||
| const OperatorNodeConfig &config); | |||||
| static SymbolVar make(SymbolVar map, SymbolVar out_diff, | |||||
| SymbolVar in_for_shape, const Param &param = {}, | |||||
| const OperatorNodeConfig &config = {}); | |||||
| }; | |||||
| /*! | |||||
| * \brief graph operator wrapping megdnn::RemapBackwardMat | |||||
| * | |||||
| * Used by the RemapForward gradient for its input 1. | |||||
| * | |||||
| * \param src the forward input image | |||||
| * \param map the coordinate map that was given to the forward operator | |||||
| * \param out_diff gradient flowing in from the forward output | |||||
| */ | |||||
| MGB_DEFINE_OPR_CLASS(RemapBackwardMat, | |||||
| intl::MegDNNOprWrapperBwd<megdnn::RemapBackwardMat>) // { | |||||
| public: | |||||
| RemapBackwardMat(VarNode *src, VarNode *map, VarNode *out_diff, | |||||
| const Param &param, const OperatorNodeConfig &config); | |||||
| static SymbolVar make(SymbolVar src, SymbolVar map, SymbolVar out_diff, | |||||
| const Param &param = {}, const OperatorNodeConfig &config = {}); | |||||
| }; | |||||
| /*! | /*! | ||||
| * \brief apply affine transformation to batched 2D images | * \brief apply affine transformation to batched 2D images | ||||
| * | * | ||||
| @@ -238,8 +256,7 @@ size_t get_workspace_size_bytes( | |||||
| const TensorShapeArray& input_shapes, | const TensorShapeArray& input_shapes, | ||||
| const TensorShapeArray& output_shapes) const override; | const TensorShapeArray& output_shapes) const override; | ||||
| void record_execute_deps(ExecDependencyArray& deps) override; | void record_execute_deps(ExecDependencyArray& deps) override; | ||||
| } | |||||
| ; | |||||
| }; | |||||
| using WarpAffine = WarpAffineForward; | using WarpAffine = WarpAffineForward; | ||||
| } // opr | } // opr | ||||
| @@ -640,11 +640,11 @@ TEST(TestOprImgproc, WarpAffineForward) { | |||||
| } | } | ||||
| TEST(TestOprImgproc, Remap_NCHW) { | TEST(TestOprImgproc, Remap_NCHW) { | ||||
| constexpr size_t N = 2, C = 8; | |||||
| constexpr size_t N = 2, C = 8, OH = 10, OW = 10; | |||||
| opr::Remap::Param param; | opr::Remap::Param param; | ||||
| using Checker = AutoOprChecker<2, 1>; | using Checker = AutoOprChecker<2, 1>; | ||||
| TensorShape out_shp{N, C, 10, 10}; | |||||
| TensorShape out_shp{N, C, OH, OW}; | |||||
| param.format = opr::Remap::Param::Format::NCHW; | param.format = opr::Remap::Param::Format::NCHW; | ||||
| auto make_graph = [&](const Checker::SymInpArray &inputs) -> | auto make_graph = [&](const Checker::SymInpArray &inputs) -> | ||||
| Checker::SymOutArray { | Checker::SymOutArray { | ||||
| @@ -657,12 +657,34 @@ TEST(TestOprImgproc, Remap_NCHW) { | |||||
| opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), dest[0].as_megdnn(), {}); | opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), dest[0].as_megdnn(), {}); | ||||
| }; | }; | ||||
| std::mt19937 rng(next_rand_seed()); | |||||
| auto rand_real = [&](double lo, double hi) { | |||||
| auto real = rng() / (std::mt19937::max() + 1.0) * (hi - lo) + lo; | |||||
| if(std::abs(std::round(real) - real) <= 1e-2) | |||||
| return real + 1e-1; | |||||
| return real; | |||||
| }; | |||||
| auto rand_real2 = [&](double range) { | |||||
| return rand_real(-range, range); | |||||
| }; | |||||
| auto gen_mat = [&](HostTensorND& mat) { | |||||
| auto ptr = mat.ptr<float>(); | |||||
| for (size_t i = 0; i < N; ++ i) { | |||||
| for(size_t j = 0; j < OH * OW * 2; j++) { | |||||
| //! undifferentiable when map is an integer | |||||
| ptr[j] = static_cast<float>(rand_real2(20)); | |||||
| } | |||||
| ptr += OH * OW * 2; | |||||
| } | |||||
| mgb_assert(ptr == mat.ptr<float>() + mat.shape().total_nr_elems()); | |||||
| }; | |||||
| Checker::RunOptions opt; | Checker::RunOptions opt; | ||||
| Checker(make_graph, fwd, CompNode::load("cpu1")) | Checker(make_graph, fwd, CompNode::load("cpu1")) | ||||
| .disable_grad_check() | |||||
| .run({TensorShape{N, C, 3, 20}, TensorShape{N, 10, 10, 2}}, opt) | |||||
| .run({TensorShape{N, C, 6, 5}, TensorShape{N, 10, 10, 2}}, opt) | |||||
| .run({TensorShape{N, C, 20, 20}, TensorShape{N, 10, 10, 2}}, opt); | |||||
| .set_input_generator(1, gen_mat) | |||||
| .run({TensorShape{N, C, 3, 20}, TensorShape{N, OH, OW, 2}}, opt) | |||||
| .run({TensorShape{N, C, 6, 5}, TensorShape{N, OH, OW, 2}}, opt) | |||||
| .run({TensorShape{N, C, 20, 20}, TensorShape{N, OH, OW, 2}}, opt); | |||||
| } | } | ||||
| TEST(TestOprImgproc, Remap_NHWC) { | TEST(TestOprImgproc, Remap_NHWC) { | ||||
| @@ -690,4 +712,5 @@ TEST(TestOprImgproc, Remap_NHWC) { | |||||
| .run({TensorShape{N, 6, 5, C}, TensorShape{N, 10, 10, 2}}, opt) | .run({TensorShape{N, 6, 5, C}, TensorShape{N, 10, 10, 2}}, opt) | ||||
| .run({TensorShape{N, 20, 20, C}, TensorShape{N, 10, 10, 2}}, opt); | .run({TensorShape{N, 20, 20, C}, TensorShape{N, 10, 10, 2}}, opt); | ||||
| } | } | ||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||