GitOrigin-RevId: 5f66d51de4
tags/v0.5.0
@@ -32,9 +32,11 @@ void IndexingRemapBase::check_layout_fwd(const TensorLayout &src,
     }
     megdnn_assert(map.shape[dst.ndim] == src.ndim, "%s", errmsg_c);
-    megdnn_assert(src.dtype == dtype::Float32());
+    megdnn_assert(dst.dtype == src.dtype);
+    megdnn_assert(src.dtype == dtype::Float32() || src.dtype == dtype::Int32(),
+                  "indexing remap only support float32/int32, got %s",
+                  src.dtype.name());
     megdnn_assert(map.dtype == dtype::Int32());
-    megdnn_assert(dst.dtype == dtype::Float32());
 }
 void IndexingRemapForward::deduce_layout(const TensorLayout &src,
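Note: the assertions above encode the indexing-remap layout contract. As a sketch inferred from the checks themselves (not part of the patch), `map` stores one full source coordinate per destination element:

    // map.ndim == dst.ndim + 1 and map.shape[dst.ndim] == src.ndim, so that
    //
    //     dst[d0, ..., d_{k-1}] = src[map[d0, ..., d_{k-1}, 0],
    //                                 ...,
    //                                 map[d0, ..., d_{k-1}, src.ndim - 1]];
    //
    // map must be Int32, dst shares src's dtype, and src may now be either
    // Float32 or Int32 instead of Float32 only.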
@@ -36,13 +36,23 @@ void IndexingRemapForwardImpl::exec(_megdnn_tensor_in src,
     for (size_t i = 0_z; i < dst.layout.ndim; ++i) {
         dshape.data[i] = dst.layout.shape[i];
     }
-    // Invoke kernel
-    tensor_remap::forward(src.ptr<dt_float32>(),
-            map.ptr<dt_int32>(),
-            dst.ptr<dt_float32>(),
-            src.layout.ndim, dst.layout.ndim,
-            sstride, dstride, dshape,
-            cuda_stream(handle()));
+    // Invoke kernel
+#define cb(dt)                                                                \
+    if (src.layout.dtype.enumv() == DTypeTrait<dt>::enumv) {                  \
+        using ctype = DTypeTrait<dt>::ctype;                                  \
+        tensor_remap::forward<ctype>(src.ptr<ctype>(), map.ptr<dt_int32>(),   \
+                                     dst.ptr<ctype>(), src.layout.ndim,       \
+                                     dst.layout.ndim, sstride, dstride,       \
+                                     dshape, cuda_stream(handle()));          \
+        return;                                                               \
+    }
+    cb(dtype::Float32)
+    cb(dtype::Int32)
+#undef cb
+    megdnn_throw(
+            ssprintf("cuda indexing remap forward only support "
+                     "float32/int32 dtype, got %s",
+                     src.layout.dtype.name()));
 }
 void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff,
@@ -69,18 +79,27 @@ void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff,
     for (size_t i = 0_z; i < diff.layout.ndim; ++i) {
         dshape.data[i] = diff.layout.shape[i];
     }
-    // Invoke kernel
-    tensor_remap::backward(diff.ptr<dt_float32>(),
-            map.ptr<dt_int32>(),
-            grad.ptr<dt_float32>(),
-            grad.layout.ndim, diff.layout.ndim,
-            sstride, dstride, sshape, dshape,
-            param().is_non_overlapping,
-            cuda_stream(handle()));
+    // Invoke kernel
+#define cb(dt)                                                                \
+    if (diff.layout.dtype.enumv() == DTypeTrait<dt>::enumv) {                 \
+        using ctype = DTypeTrait<dt>::ctype;                                  \
+        tensor_remap::backward<ctype>(                                        \
+                diff.ptr<ctype>(), map.ptr<dt_int32>(), grad.ptr<ctype>(),    \
+                grad.layout.ndim, diff.layout.ndim, sstride, dstride, sshape, \
+                dshape, param().is_non_overlapping, cuda_stream(handle()));   \
+        return;                                                               \
+    }
+    cb(dtype::Float32)
+    cb(dtype::Int32)
+#undef cb
+    megdnn_throw(
+            ssprintf("cuda indexing remap backward only support "
+                     "float32/int32 dtype, got %s",
+                     diff.layout.dtype.name()));
 }
 } // namespace cuda
 } // namespace megdnn
 // vim: syntax=cpp.doxygen
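Note: the dispatch above is macro-generated per dtype. For readability, `cb(dtype::Float32)` in the forward path expands to roughly the following (a sketch of the expansion; `dt_float32` is the ctype that `DTypeTrait` maps `dtype::Float32` to):

    if (src.layout.dtype.enumv() == DTypeTrait<dtype::Float32>::enumv) {
        using ctype = DTypeTrait<dtype::Float32>::ctype;  // i.e. dt_float32
        // Launch the float32 instantiation of the kernel and stop dispatching.
        tensor_remap::forward<ctype>(src.ptr<ctype>(), map.ptr<dt_int32>(),
                                     dst.ptr<ctype>(), src.layout.ndim,
                                     dst.layout.ndim, sstride, dstride, dshape,
                                     cuda_stream(handle()));
        return;
    }

Any dtype that matches no `cb(...)` branch falls through to the `megdnn_throw`.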
@@ -6,28 +6,29 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
-#include "src/cuda/tensor_remap/tensor_remap.cuh"
 #include "src/cuda/query_blocksize.cuh"
+#include "src/cuda/tensor_remap/tensor_remap.cuh"
 namespace megdnn {
 namespace cuda {
-namespace tensor_remap {
+namespace {
-__global__ void forward_kernel(const float *src, const int *map, float *dst,
-                               uint32_t sdim, uint32_t ddim,
-                               array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
-                               array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
-                               array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
-                               uint32_t total)
-{
+template <typename ctype>
+__global__ void forward_kernel(const ctype* src, const int* map, ctype* dst,
+                               uint32_t sdim, uint32_t ddim,
+                               array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
+                               array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
+                               array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
+                               uint32_t total) {
     uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x;
     if (didx_cont < total) {
         uint32_t midx = didx_cont * sdim;
         uint32_t didx = 0u;
         for (uint32_t j = ddim; j > 0u; --j) {
-            uint32_t i = j-1u;
+            uint32_t i = j - 1u;
             uint32_t didx_cur = didx_cont % dshape.data[i];
             didx_cont /= dshape.data[i];
             didx += didx_cur * dstride.data[i];
@@ -41,34 +42,16 @@ __global__ void forward_kernel(const float *src, const int *map, float *dst,
     }
 }
-void forward(const float *src, const int *map, float *dst,
-        uint32_t sdim, uint32_t ddim,
-        const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
-        const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
-        const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
-        cudaStream_t stream)
-{
-    uint32_t total = 1u;
-    for (uint32_t i = 0u; i < ddim; ++i) total *= dshape.data[i];
-    uint32_t threads = query_blocksize_for_kernel((void *)&forward_kernel);
-    uint32_t blocks = DIVUP(total, threads);
-    forward_kernel<<<blocks, threads, 0, stream>>>(src, map, dst,
-            sdim, ddim,
-            sstride, dstride, dshape,
-            total);
-    after_kernel_launch();
-}
-__global__ void fill_zero_kernel(float *a, uint32_t dim,
-        array_wrapper<int, MEGDNN_MAX_NDIM> stride,
-        array_wrapper<uint32_t, MEGDNN_MAX_NDIM> shape,
-        uint32_t total)
-{
+template <typename ctype>
+__global__ void fill_zero_kernel(ctype* a, uint32_t dim,
+                                 array_wrapper<int, MEGDNN_MAX_NDIM> stride,
+                                 array_wrapper<uint32_t, MEGDNN_MAX_NDIM> shape,
+                                 uint32_t total) {
     uint32_t idx_cont = threadIdx.x + blockIdx.x * blockDim.x;
     if (idx_cont < total) {
         uint32_t idx = 0u;
         for (uint32_t j = dim; j > 0u; --j) {
-            uint32_t i = j-1u;
+            uint32_t i = j - 1u;
             uint32_t idx_cur = idx_cont % shape.data[i];
             idx_cont /= shape.data[i];
             idx += idx_cur * stride.data[i];
@@ -77,19 +60,19 @@ __global__ void fill_zero_kernel(float *a, uint32_t dim,
     }
 }
-__global__ void backward_kernel(const float *diff, const int *map, float *grad,
-                                uint32_t sdim, uint32_t ddim,
-                                array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
-                                array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
-                                array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
-                                uint32_t total)
-{
+template <typename ctype>
+__global__ void backward_kernel(const ctype* diff, const int* map, ctype* grad,
+                                uint32_t sdim, uint32_t ddim,
+                                array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
+                                array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
+                                array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
+                                uint32_t total) {
     uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x;
     if (didx_cont < total) {
         uint32_t midx = didx_cont * sdim;
         uint32_t didx = 0u;
         for (uint32_t j = ddim; j > 0u; --j) {
-            uint32_t i = j-1u;
+            uint32_t i = j - 1u;
             uint32_t didx_cur = didx_cont % dshape.data[i];
             didx_cont /= dshape.data[i];
             didx += didx_cur * dstride.data[i];
@@ -103,20 +86,18 @@ __global__ void backward_kernel(const float *diff, const int *map, float *grad,
     }
 }
+template <typename ctype>
 __global__ void backward_kernel_non_overlapping(
-        const float *diff, const int *map, float *grad,
-        uint32_t sdim, uint32_t ddim,
-        array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
+        const ctype* diff, const int* map, ctype* grad, uint32_t sdim,
+        uint32_t ddim, array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
         array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
-        array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
-        uint32_t total)
-{
+        array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape, uint32_t total) {
     uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x;
     if (didx_cont < total) {
         uint32_t midx = didx_cont * sdim;
         uint32_t didx = 0u;
         for (uint32_t j = ddim; j > 0u; --j) {
-            uint32_t i = j-1u;
+            uint32_t i = j - 1u;
             uint32_t didx_cur = didx_cont % dshape.data[i];
             didx_cont /= dshape.data[i];
             didx += didx_cur * dstride.data[i];
@@ -130,55 +111,91 @@ __global__ void backward_kernel_non_overlapping(
     }
 }
-void backward(const float *diff, const int *map, float *grad,
-        uint32_t sdim, uint32_t ddim,
-        const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
-        const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
-        const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &sshape,
-        const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
-        bool is_non_overlapping,
-        cudaStream_t stream)
-{
+} // anonymous namespace
+namespace tensor_remap {
+template <typename ctype>
+void forward(const ctype* src, const int* map, ctype* dst, uint32_t sdim,
+             uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
+             const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
+             const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
+             cudaStream_t stream) {
+    uint32_t total = 1u;
+    for (uint32_t i = 0u; i < ddim; ++i)
+        total *= dshape.data[i];
+    uint32_t threads =
+            query_blocksize_for_kernel((void*)&forward_kernel<ctype>);
+    uint32_t blocks = DIVUP(total, threads);
+    forward_kernel<ctype><<<blocks, threads, 0, stream>>>(
+            src, map, dst, sdim, ddim, sstride, dstride, dshape, total);
+    after_kernel_launch();
+}
+template <typename ctype>
+void backward(const ctype* diff, const int* map, ctype* grad, uint32_t sdim,
+              uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
+              const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
+              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape,
+              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
+              bool is_non_overlapping, cudaStream_t stream) {
     {
         // Fill grad with zeros.
         uint32_t total = 1u;
-        for (uint32_t i = 0u; i < sdim; ++i) total *= sshape.data[i];
-        uint32_t threads = query_blocksize_for_kernel((void *)&fill_zero_kernel);
+        for (uint32_t i = 0u; i < sdim; ++i)
+            total *= sshape.data[i];
+        uint32_t threads =
+                query_blocksize_for_kernel((void*)&fill_zero_kernel<ctype>);
         uint32_t blocks = DIVUP(total, threads);
-        fill_zero_kernel<<<blocks, threads, 0, stream>>>(
+        fill_zero_kernel<ctype><<<blocks, threads, 0, stream>>>(
                 grad, sdim, sstride, sshape, total);
         after_kernel_launch();
     }
     {
         // Update grad.
         uint32_t total = 1u;
-        for (uint32_t i = 0u; i < ddim; ++i) total *= dshape.data[i];
+        for (uint32_t i = 0u; i < ddim; ++i)
+            total *= dshape.data[i];
         if (is_non_overlapping) {
             uint32_t threads = query_blocksize_for_kernel(
-                    (void *)&backward_kernel_non_overlapping);
+                    (void*)&backward_kernel_non_overlapping<ctype>);
             uint32_t blocks = DIVUP(total, threads);
-            backward_kernel_non_overlapping<<<blocks, threads, 0, stream>>>(
-                    diff, map, grad,
-                    sdim, ddim,
-                    sstride, dstride, dshape,
-                    total);
+            backward_kernel_non_overlapping<ctype>
+                    <<<blocks, threads, 0, stream>>>(diff, map, grad, sdim,
                                                     ddim, sstride, dstride,
                                                     dshape, total);
         } else {
-            uint32_t threads = query_blocksize_for_kernel(
-                    (void *)&backward_kernel);
+            uint32_t threads =
+                    query_blocksize_for_kernel((void*)&backward_kernel<ctype>);
             uint32_t blocks = DIVUP(total, threads);
-            backward_kernel<<<blocks, threads, 0, stream>>>(diff, map, grad,
-                    sdim, ddim,
-                    sstride, dstride, dshape,
+            backward_kernel<ctype><<<blocks, threads, 0, stream>>>(
+                    diff, map, grad, sdim, ddim, sstride, dstride, dshape,
                     total);
         }
         after_kernel_launch();
     }
 }
+#define INST(T)                                                                \
+    template void forward<T>(                                                  \
+            const T* src, const int* map, T* dst, uint32_t sdim,               \
+            uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride, \
+            const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,                \
+            const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,            \
+            cudaStream_t stream);                                              \
+    template void backward<T>(                                                 \
+            const T* diff, const int* map, T* grad, uint32_t sdim,             \
+            uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride, \
+            const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,                \
+            const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape,            \
+            const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,            \
+            bool is_non_overlapping, cudaStream_t stream);
+INST(dt_float32)
+INST(dt_int32)
+#undef INST
 } // namespace tensor_remap
 } // namespace cuda
 } // namespace megdnn
 // vim: syntax=cpp.doxygen
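Note: `forward` and `backward` are now templates whose definitions live only in this .cu file, so callers that include `tensor_remap.cuh` (such as the CUDA `opr_impl.cpp` above) link against the explicit instantiations emitted by `INST(dt_float32)` and `INST(dt_int32)`; without them the templated symbols would be undefined at link time. `INST(dt_float32)`, for instance, expands to roughly:

    // Sketch of the expansion for the forward half; backward is analogous.
    template void forward<dt_float32>(
            const dt_float32* src, const int* map, dt_float32* dst,
            uint32_t sdim, uint32_t ddim,
            const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
            const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
            const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
            cudaStream_t stream);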
@@ -17,25 +17,23 @@ namespace megdnn {
 namespace cuda {
 namespace tensor_remap {
-void forward(const float *src, const int *map, float *dst,
-        uint32_t sdim, uint32_t ddim,
-        const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
-        const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
-        const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
-        cudaStream_t stream);
+template <typename ctype>
+void forward(const ctype* src, const int* map, ctype* dst, uint32_t sdim,
+             uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
+             const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
+             const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
+             cudaStream_t stream);
-void backward(const float *diff, const int *map, float *grad,
-        uint32_t sdim, uint32_t ddim,
-        const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
-        const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
-        const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &sshape,
-        const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
-        bool is_non_overlapping,
-        cudaStream_t stream);
+template <typename ctype>
+void backward(const ctype* diff, const int* map, ctype* grad, uint32_t sdim,
+              uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
+              const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
+              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape,
+              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
+              bool is_non_overlapping, cudaStream_t stream);
 } // namespace tensor_remap
 } // namespace cuda
 } // namespace megdnn
 // vim: syntax=cpp.doxygen
@@ -6,75 +6,107 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "src/naive/tensor_remap/opr_impl.h"
 #include "src/common/utils.h"
 #include "src/naive/handle.h"
-namespace megdnn {
-namespace naive {
+using namespace megdnn;
+using namespace naive;
+namespace {
+template <typename ctype>
+void forward(const TensorND& src, const TensorND& map, const TensorND& dst) {
+    auto&& sshape = src.layout;
+    auto&& mshape = map.layout;
+    auto&& dshape = dst.layout;
+    // Last element is zero to facilitate maddr calculation.
+    std::vector<size_t> didx(dshape.ndim + 1, 0_z);
+    do {
+        auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
+        std::vector<size_t> sidx(sshape.ndim);
+        for (size_t i = 0_z; i < sshape.ndim; ++i) {
+            sidx[i] = map.ptr<dt_int32>()[maddr + i];
+        }
+        auto saddr = get_linear_addr_noncont(sidx.data(), src.layout);
+        auto daddr = get_linear_addr_noncont(didx.data(), dst.layout);
+        dst.ptr<ctype>()[daddr] = src.ptr<ctype>()[saddr];
+    } while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
+}
+template <typename ctype>
+void backward(const TensorND& diff, const TensorND& map, const TensorND& grad) {
+    auto&& sshape = grad.layout;
+    auto&& mshape = map.layout;
+    auto&& dshape = diff.layout;
+    std::vector<size_t> sidx(sshape.ndim, 0_z);
+    {
+        // Set grad to zero.
+        do {
+            auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
+            grad.ptr<ctype>()[saddr] = 0.0f;
+        } while (get_next_addr(sidx.data(), sshape.shape, sshape.ndim));
+    }
+    std::vector<size_t> didx(dshape.ndim + 1, 0_z);
+    do {
+        auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
+        std::vector<size_t> sidx(sshape.ndim);
+        for (size_t i = 0_z; i < sshape.ndim; ++i) {
+            sidx[i] = map.ptr<dt_int32>()[maddr + i];
+        }
+        auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
+        auto daddr = get_linear_addr_noncont(didx.data(), diff.layout);
+        grad.ptr<ctype>()[saddr] += diff.ptr<ctype>()[daddr];
+    } while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
+}
+} // anonymous namespace
 void IndexingRemapForwardImpl::exec(_megdnn_tensor_in src,
-        _megdnn_tensor_in map,
-        _megdnn_tensor_out dst,
-        _megdnn_workspace workspace)
-{
+                                    _megdnn_tensor_in map,
+                                    _megdnn_tensor_out dst,
+                                    _megdnn_workspace workspace) {
     check_exec(src.layout, map.layout, dst.layout, workspace.size);
-    auto kern = [=]() {
-        auto &&sshape = src.layout;
-        auto &&mshape = map.layout;
-        auto &&dshape = dst.layout;
-        // Last element is zero to facilitate maddr calculation.
-        std::vector<size_t> didx(dshape.ndim+1, 0_z);
-        do {
-            auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
-            std::vector<size_t> sidx(sshape.ndim);
-            for (size_t i = 0_z; i < sshape.ndim; ++i) {
-                sidx[i] = map.ptr<dt_int32>()[maddr+i];
-            }
-            auto saddr = get_linear_addr_noncont(sidx.data(), src.layout);
-            auto daddr = get_linear_addr_noncont(didx.data(), dst.layout);
-            dst.ptr<dt_float32>()[daddr] = src.ptr<dt_float32>()[saddr];
-        } while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
-    };
-    MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
+    switch (src.layout.dtype.enumv()) {
+#define cb(dt)                                                      \
+    case DTypeTrait<dt>::enumv:                                     \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                               \
+                forward<DTypeTrait<dt>::ctype>(src, map, dst));     \
+        return;
+        cb(dtype::Float32)
+        cb(dtype::Int32)
+#undef cb
+        default:
+            megdnn_throw(
+                    ssprintf("unsupported dtype %s in indexing "
+                             "remap forward naive\n",
+                             src.layout.dtype.name()));
+    }
 }
 void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff,
-        _megdnn_tensor_in map,
-        _megdnn_tensor_out grad,
-        _megdnn_workspace workspace)
-{
+                                     _megdnn_tensor_in map,
+                                     _megdnn_tensor_out grad,
+                                     _megdnn_workspace workspace) {
     check_exec(diff.layout, map.layout, grad.layout, workspace.size);
-    auto kern = [=]() {
-        auto &&sshape = grad.layout;
-        auto &&mshape = map.layout;
-        auto &&dshape = diff.layout;
-        std::vector<size_t> sidx(sshape.ndim, 0_z);
-        {
-            // Set grad to zero.
-            do {
-                auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
-                grad.ptr<dt_float32>()[saddr] = 0.0f;
-            } while (get_next_addr(sidx.data(), sshape.shape, sshape.ndim));
-        }
-        std::vector<size_t> didx(dshape.ndim+1, 0_z);
-        do {
-            auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
-            std::vector<size_t> sidx(sshape.ndim);
-            for (size_t i = 0_z; i < sshape.ndim; ++i) {
-                sidx[i] = map.ptr<dt_int32>()[maddr+i];
-            }
-            auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
-            auto daddr = get_linear_addr_noncont(didx.data(), diff.layout);
-            grad.ptr<dt_float32>()[saddr] += diff.ptr<dt_float32>()[daddr];
-        } while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
-    };
-    MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
+    switch (diff.layout.dtype.enumv()) {
+#define cb(dt)                                                       \
+    case DTypeTrait<dt>::enumv:                                      \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                                \
+                backward<DTypeTrait<dt>::ctype>(diff, map, grad));   \
+        return;
+        cb(dtype::Float32)
+        cb(dtype::Int32)
+#undef cb
+        default:
+            megdnn_throw(ssprintf(
+                    "unsupported dtype %s in indexing remap backward naive\n",
+                    diff.layout.dtype.name()));
+    }
 }
-} // namespace naive
-} // namespace megdnn
 // vim: syntax=cpp.doxygen
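Note: the naive backward above zero-fills `grad` and then accumulates (`+=`) every `diff` element into the source position its `map` entry selects, so overlapping map entries sum. A hand-worked 1-D example (not from the patch) of why this matters:

    // src.ndim == 1, src has 4 elements; dst/diff have 3; map = {2, 0, 2},
    // i.e. dst[i] = src[map[i]].
    //
    // forward:  dst  = { src[2], src[0], src[2] }
    // backward: grad = { diff[1], 0, diff[0] + diff[2], 0 }
    //
    // grad[2] receives two contributions because the map overlaps; only when
    // param().is_non_overlapping promises this never happens may the CUDA
    // path take the cheaper backward_kernel_non_overlapping.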
@@ -16,39 +16,42 @@
 namespace megdnn {
 namespace test {
-TEST_F(CUDA, TENSOR_REMAP_FORWARD)
-{
+TEST_F(CUDA, TENSOR_REMAP_FORWARD) {
     Checker<IndexingRemapForward> checker(handle_cuda());
+    TensorShape src{11, 13, 17}, map{3, 5, 7, 3}, dst{3, 5, 7};
     checker.set_dtype(1, dtype::Int32());
-    TensorShape src{11, 13, 17},
-                map{3, 5, 7, 3},
-                dst{3, 5, 7};
-    using namespace tensor_remap;
-    {
-        MapRNG rng(src);
-        checker.set_rng(1, &rng).execs({src, map, {}});
-    }
-    {
-        NonoverlappingMapRNG rng(src);
-        checker.set_rng(1, &rng).execs({src, map, {}});
+    for (auto dt : std::vector<DType>{dtype::Float32(), dtype::Int32()}) {
+        checker.set_dtype(0, dt);
+        checker.set_dtype(2, dt);
+        using namespace tensor_remap;
+        {
+            MapRNG rng(src);
+            checker.set_rng(1, &rng).execs({src, map, {}});
+        }
+        {
+            NonoverlappingMapRNG rng(src);
+            checker.set_rng(1, &rng).execs({src, map, {}});
+        }
     }
 }
-TEST_F(CUDA, TENSOR_REMAP_BACKWARD)
-{
+TEST_F(CUDA, TENSOR_REMAP_BACKWARD) {
     Checker<IndexingRemapBackward> checker(handle_cuda());
     checker.set_dtype(1, dtype::Int32());
-    TensorShape src{11, 13, 17},
-                map{3, 5, 7, 3},
-                dst{3, 5, 7};
-    using namespace tensor_remap;
-    {
-        MapRNG rng(src);
-        checker.set_rng(1, &rng).execs({dst, map, src});
-    }
-    {
-        NonoverlappingMapRNG rng(src);
-        checker.set_rng(1, &rng).execs({dst, map, src});
+    TensorShape src{11, 13, 17}, map{3, 5, 7, 3}, dst{3, 5, 7};
+    for (auto dt : std::vector<DType>{dtype::Float32(), dtype::Int32()}) {
+        checker.set_dtype(0, dt);
+        checker.set_dtype(2, dt);
+        using namespace tensor_remap;
+        {
+            MapRNG rng(src);
+            checker.set_rng(1, &rng).execs({dst, map, src});
+        }
+        {
+            NonoverlappingMapRNG rng(src);
+            checker.set_rng(1, &rng).execs({dst, map, src});
+        }
     }
 }
@@ -56,5 +59,3 @@ TEST_F(CUDA, TENSOR_REMAP_BACKWARD)
 } // namespace megdnn
 // vim: syntax=cpp.doxygen