GitOrigin-RevId: 43016ffa2b
tags/v1.8.0
@@ -998,6 +998,28 @@ protected:
    void check_exec(const TensorLayout& dst, size_t workspace_in_bytes);
};

class Diag : public OperatorBase {
    DEF_OPR_IMPL(Diag, OperatorBase, 1, 1);
    DEF_OPR_PARAM(Diag);

public:
    /**
     * \see http://docs.scipy.org/doc/numpy/reference/generated/numpy.diag.html
     */
    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& src, TensorLayout& dst);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) = 0;

protected:
    void check_exec(
            const TensorLayout& src, const TensorLayout& dst,
            size_t workspace_in_bytes);
};
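
The new operator mirrors numpy.diag (see the \see link above): a 1-D input is placed on the k-th diagonal of a square output matrix, while a 2-D input has its k-th diagonal extracted. A minimal NumPy sketch of the intended forward semantics, for orientation only (not part of the commit):

    import numpy as np

    vec = np.array([1, 2, 3])
    # vector -> matrix: k = 1 puts the values on the first superdiagonal
    assert np.array_equal(np.diag(vec, k=1),
                          [[0, 1, 0, 0],
                           [0, 0, 2, 0],
                           [0, 0, 0, 3],
                           [0, 0, 0, 0]])

    mat = np.arange(9).reshape(3, 3)   # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    # matrix -> vector: k = -1 extracts the first subdiagonal
    assert np.array_equal(np.diag(mat, k=-1), [3, 7])
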
class IndexingOneHotBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(IndexingOneHotBase, OperatorBase);
    DEF_OPR_PARAM(Axis);

@@ -759,6 +759,14 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
         'dtype', Doc('dtype', 'data type of output value'),
         'DTypeEnum::Float32'))

(pdef('Diag').
 add_fields(
     'int32',
     Doc('k', 'Index of the diagonal: 0 (the default) refers to the main '
         'diagonal, a positive value refers to an upper diagonal, and a '
         'negative value to a lower diagonal.'),
     0))

(pdef('UniformRNG', version=0, is_legacy=True).
 add_fields('uint64', 'seed', 0))

@@ -0,0 +1,47 @@
/**
 * \file dnn/src/common/diag.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/oprs.h"

#include "src/common/utils.h"

namespace megdnn {

void Diag::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
    megdnn_assert(
            src.ndim == 1 || src.ndim == 2,
            "Only vector or matrix input is supported.");
    int k = param().k;
    if (src.ndim == 1) {
        size_t o = src.total_nr_elems() + std::abs(k);
        dst = TensorLayout(TensorShape({o, o}), src.dtype);
    } else {  // src.ndim == 2
        size_t m = src.shape[0];
        size_t n = src.shape[1];
        size_t o = (k >= 0 ? std::min(n - k, m) : std::min(m + k, n));
        megdnn_assert(o > 0, "The requested diagonal lies outside the input matrix.");
        dst = TensorLayout(TensorShape({o}), src.dtype);
    }
}
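
The shape rule that deduce_layout encodes, restated as a short Python sketch (illustrative only):

    def diag_out_shape(src_shape, k):
        if len(src_shape) == 1:        # vector -> (n + |k|, n + |k|) matrix
            o = src_shape[0] + abs(k)
            return (o, o)
        m, n = src_shape               # matrix -> length of the k-th diagonal
        o = min(n - k, m) if k >= 0 else min(m + k, n)
        assert o > 0, "diagonal out of range"
        return (o,)

    assert diag_out_shape((3,), -2) == (5, 5)
    assert diag_out_shape((8, 6), 1) == (5,)
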
void Diag::check_exec(
        const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
    TensorLayout dst_expected;
    megdnn_assert_eq_dtype(src, dst);
    deduce_layout(src, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
    megdnn_assert_contiguous(dst);
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

}  // namespace megdnn

// vim: syntax=cpp.doxygen

@@ -146,6 +146,7 @@ private:
    cb(BatchedSetMeshIndexing) \
    cb(Linspace) \
    cb(Eye) \
    cb(Diag) \
    cb(SleepForward) \
    cb(UniformRNG) \
    cb(GaussianRNG) \

@@ -88,6 +88,7 @@ DEF(IndexingRemapForward, 3, true, true);
DEF(IndexingRemapBackward, 3, true, false);
DEF(Linspace, 1, true, false);
DEF(Eye, 1, true, false);
DEF(Diag, 2, true, true);
DEF(Flip, 2, true, true);
DEF(ROICopy, 2, true, true);
DEF(Rotate, 2, true, true);
@@ -0,0 +1,87 @@
/**
 * \file dnn/src/cuda/diag/diag.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/dtype.h"
#include "src/cuda/diag/diag.cuh"
#include "src/cuda/utils.cuh"

namespace {

template <typename T>
__global__ void kernel_to_vector(
        T* src, T* dst, ptrdiff_t start, ptrdiff_t size, ptrdiff_t stride_sum,
        ptrdiff_t dst_stride) {
    ptrdiff_t i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i < size) {
        dst[dst_stride * i] = src[start + stride_sum * i];
    }
}

template <typename T>
__global__ void kernel_to_matrix(
        T* src, T* dst, ptrdiff_t offset, ptrdiff_t n, ptrdiff_t k,
        ptrdiff_t dst_stride0, ptrdiff_t dst_stride1, ptrdiff_t src_stride) {
    ptrdiff_t i = threadIdx.x + blockIdx.x * blockDim.x;
    ptrdiff_t x = i % n;
    ptrdiff_t y = i / n;
    ptrdiff_t p = dst_stride0 * y + dst_stride1 * x;
    if (i < n * n) {
        if (y + k == x)
            dst[p] = src[src_stride * (y - offset)];
        else
            dst[p] = 0;
    }
}

}  // anonymous namespace
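
Both kernels assign one thread per output element. kernel_to_vector walks the k-th diagonal of a strided matrix: element i sits at flat offset start + i * (stride0 + stride1), where start is k * stride1 for k >= 0 and -k * stride0 otherwise. kernel_to_matrix writes src[y - offset] at position (y, x) whenever x - y == k, and zero elsewhere. A plain-Python model of that index arithmetic (illustrative sketch; element strides assumed):

    def diag_flat_index(i, k, s0, s1):
        start = k * s1 if k >= 0 else -k * s0
        return start + i * (s0 + s1)

    # a 3x4 row-major matrix has element strides (4, 1); element 1 of the
    # k=1 diagonal is at (1, 2), i.e. flat index 1 * 4 + 2 * 1 = 6
    assert diag_flat_index(1, 1, 4, 1) == 6

    def to_matrix(src, n, k):
        offset = 0 if k >= 0 else -k
        return [[src[y - offset] if y + k == x else 0 for x in range(n)]
                for y in range(n)]

    assert to_matrix([1, 2], 3, -1) == [[0, 0, 0], [1, 0, 0], [0, 2, 0]]
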
namespace megdnn {
namespace cuda {
namespace diag {

template <typename T>
void exec_internal_to_vector(
        T* src, T* dst, ptrdiff_t start, ptrdiff_t size, ptrdiff_t stride_sum,
        ptrdiff_t dst_stride, cudaStream_t stream) {
    kernel_to_vector<T><<<DIVUP(size, NR_THREADS), NR_THREADS, 0, stream>>>(
            src, dst, start, size, stride_sum, dst_stride);
    after_kernel_launch();
}

template <typename T>
void exec_internal_to_matrix(
        T* src, T* dst, ptrdiff_t offset, ptrdiff_t n, ptrdiff_t k,
        ptrdiff_t dst_stride0, ptrdiff_t dst_stride1, ptrdiff_t src_stride,
        cudaStream_t stream) {
    kernel_to_matrix<T><<<DIVUP(n * n, NR_THREADS), NR_THREADS, 0, stream>>>(
            src, dst, offset, n, k, dst_stride0, dst_stride1, src_stride);
    after_kernel_launch();
}
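
The launch configuration rounds the element count up to whole blocks, one thread per element; out-of-range threads fall through the bounds checks inside the kernels. A sketch of the arithmetic (the concrete NR_THREADS value comes from src/cuda/utils.cuh; the number below is an assumption for illustration):

    def divup(a, b):
        return (a + b - 1) // b

    NR_THREADS = 1024  # assumed block size; see src/cuda/utils.cuh
    size = 130
    blocks = divup(size, NR_THREADS)
    assert blocks * NR_THREADS >= size  # enough threads to cover every element
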
#define INST(T)                               \
    template void exec_internal_to_vector<T>( \
            T*, T*, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)
#undef INST
#undef cb

#define INST(T)                                                                   \
    template void exec_internal_to_matrix<T>(                                     \
            T*, T*, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t, ptrdiff_t,        \
            ptrdiff_t, cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
cb(::megdnn::dtype::Bool)

}  // namespace diag
}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -0,0 +1,33 @@
/**
 * \file dnn/src/cuda/diag/diag.cuh
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include <cuda_runtime_api.h>
#include <stdint.h>

namespace megdnn {
namespace cuda {
namespace diag {

template <typename T>
void exec_internal_to_vector(
        T* src, T* dst, ptrdiff_t start, ptrdiff_t size, ptrdiff_t stride_sum,
        ptrdiff_t dst_stride, cudaStream_t stream);

template <typename T>
void exec_internal_to_matrix(
        T* src, T* dst, ptrdiff_t start, ptrdiff_t n, ptrdiff_t k,
        ptrdiff_t dst_stride0, ptrdiff_t dst_stride1, ptrdiff_t src_stride,
        cudaStream_t stream);

}  // namespace diag
}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -0,0 +1,61 @@
/**
 * \file dnn/src/cuda/diag/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/cuda/diag/opr_impl.h"
#include "src/cuda/diag/diag.cuh"
#include "src/cuda/utils.h"

namespace megdnn {
namespace cuda {

void DiagImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);
    if (src.layout.ndim == 2) {
        auto src_stride0 = src.layout.stride[0];
        auto src_stride1 = src.layout.stride[1];
        auto dst_stride = dst.layout.stride[0];
        auto start =
                (param().k >= 0) ? param().k * src_stride1 : -param().k * src_stride0;
#define cb(DType)                                                                \
    if (dst.layout.dtype.enumv() == DTypeTrait<DType>::enumv) {                  \
        using ctype = typename DTypeTrait<DType>::ctype;                         \
        diag::exec_internal_to_vector<ctype>(                                    \
                src.ptr<ctype>(), dst.ptr<ctype>(), start, dst.layout.shape[0],  \
                src_stride0 + src_stride1, dst_stride, cuda_stream(handle()));   \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
        cb(::megdnn::dtype::Bool)
#undef cb
    } else {
        auto n = dst.layout.shape[0];
        auto src_stride = src.layout.stride[0];
        auto dst_stride0 = dst.layout.stride[0];
        auto dst_stride1 = dst.layout.stride[1];
        auto offset = (param().k >= 0) ? 0 : -param().k;
#define cb(DType)                                                                  \
    if (dst.layout.dtype.enumv() == DTypeTrait<DType>::enumv) {                    \
        using ctype = typename DTypeTrait<DType>::ctype;                           \
        diag::exec_internal_to_matrix<ctype>(                                      \
                src.ptr<ctype>(), dst.ptr<ctype>(), offset, n, param().k,          \
                dst_stride0, dst_stride1, src_stride, cuda_stream(handle()));      \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
        cb(::megdnn::dtype::Bool)
#undef cb
    }
}

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -0,0 +1,31 @@
/**
 * \file dnn/src/cuda/diag/opr_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megdnn/oprs.h"

namespace megdnn {
namespace cuda {

class DiagImpl final : public Diag {
public:
    using Diag::Diag;
    void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout& dst) override {
        return 0;
    }
};

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

@@ -33,6 +33,7 @@
#include "src/cuda/dct/opr_impl.h"
#include "src/cuda/deformable_conv/opr_impl.h"
#include "src/cuda/deformable_ps_roi_pooling/opr_impl.h"
#include "src/cuda/diag/opr_impl.h"
#include "src/cuda/dot/opr_impl.h"
#include "src/cuda/dropout/opr_impl.h"
#include "src/cuda/elemwise/opr_impl.h"

@@ -154,6 +155,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedIncrMeshIndexing);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedSetMeshIndexing);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Eye);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Diag);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(UniformRNG);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianRNG);
@@ -0,0 +1,60 @@
/**
 * \file dnn/src/naive/diag/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/naive/diag/opr_impl.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"

namespace megdnn {
namespace naive {

template <typename ctype>
void DiagImpl::exec_internal(
        ctype* src, const TensorLayout& src_layout, ctype* dst,
        const TensorLayout& dst_layout, size_t input_ndim, int k) {
    if (input_ndim == 1) {
        size_t l = src_layout.shape[0];
        size_t s0 = dst_layout.stride[0];
        size_t s1 = dst_layout.stride[1];
        size_t start = (k >= 0) ? (k * s1) : (-k * s0);
        for (size_t i = 0; i < dst_layout.shape[0]; ++i)
            for (size_t j = 0; j < dst_layout.shape[1]; ++j)
                dst[i * s0 + j * s1] = 0;
        for (size_t i = 0; i < l; ++i)
            dst[start + i * (s0 + s1)] = src[i];
    } else {
        size_t l = dst_layout.shape[0];
        size_t s0 = src_layout.stride[0];
        size_t s1 = src_layout.stride[1];
        size_t start = (k >= 0) ? (k * s1) : (-k * s0);
        for (size_t i = 0; i < l; ++i)
            dst[i] = src[start + i * (s0 + s1)];
    }
}
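
The same algorithm in plain Python, for reference (illustrative sketch; flat row-major storage with element strides (n, 1) assumed):

    def diag_naive(src, src_shape, k):
        if len(src_shape) == 1:            # vector -> zeroed matrix + diagonal
            n = src_shape[0] + abs(k)
            dst = [[0] * n for _ in range(n)]
            for i, v in enumerate(src):
                r, c = (i, i + k) if k >= 0 else (i - k, i)
                dst[r][c] = v
            return dst
        m, n = src_shape                   # matrix -> extract the diagonal
        o = min(n - k, m) if k >= 0 else min(m + k, n)
        start = k if k >= 0 else -k * n    # flat offset of the first element
        return [src[start + i * (n + 1)] for i in range(o)]

    assert diag_naive([1, 2], (2,), -1) == [[0, 0, 0], [1, 0, 0], [0, 2, 0]]
    assert diag_naive(list(range(9)), (3, 3), 1) == [1, 5]
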
void DiagImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, workspace.size);
#define cb(DType)                                                            \
    if (src.layout.dtype == DType()) {                                       \
        using ctype = typename DTypeTrait<DType>::ctype;                     \
        MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>(                   \
                src.ptr<ctype>(), src.layout, dst.ptr<ctype>(), dst.layout,  \
                src.layout.ndim, param().k));                                \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
    cb(::megdnn::dtype::Bool)
#undef cb
}

}  // namespace naive
}  // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -0,0 +1,37 @@
/**
 * \file dnn/src/naive/diag/opr_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megdnn/oprs.h"

namespace megdnn {
namespace naive {

class DiagImpl : public Diag {
public:
    using Diag::Diag;
    void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override {
        return 0;
    }

private:
    template <typename ctype>
    void exec_internal(
            ctype* src, const TensorLayout& src_layout, ctype* dst,
            const TensorLayout& dst_layout, size_t input_ndim, int k);
};

}  // namespace naive
}  // namespace megdnn

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

@@ -34,6 +34,7 @@
#include "src/naive/dct/opr_impl.h"
#include "src/naive/deformable_conv/opr_impl.h"
#include "src/naive/deformable_ps_roi_pooling/opr_impl.h"
#include "src/naive/diag/opr_impl.h"
#include "src/naive/dot/opr_impl.h"
#include "src/naive/dropout/opr_impl.h"
#include "src/naive/elemwise/opr_impl.h"
@@ -0,0 +1,42 @@
/**
 * \file dnn/test/cuda/diag.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "test/cuda/fixture.h"

#include "megdnn/oprs.h"
#include "test/common/checker.h"

namespace megdnn {
namespace test {

TEST_F(CUDA, DIAG) {
    Checker<Diag> checker(handle_cuda());
    for (DType dtype :
         std::vector<DType>{dtype::Float16(), dtype::Int32(), dtype::Float32()})
        for (int k = -5; k < 5; ++k) {
            checker.set_param({k});
            checker.set_dtype(0, dtype);
            checker.set_dtype(1, dtype);
            size_t absk = static_cast<size_t>(std::abs(k));
            checker.exec(TensorShapeArray{{8}, {8 + absk, 8 + absk}});

            auto oshape = [&](int n, int m) -> TensorShape {
                size_t o = (k >= 0 ? std::min(n - k, m) : std::min(m + k, n));
                return {o, o};
            };
            checker.exec(TensorShapeArray{{8, 6}, oshape(8, 6)});
            checker.exec(TensorShapeArray{{6, 8}, oshape(6, 8)});
            checker.exec(TensorShapeArray{{8, 8}, oshape(8, 8)});
        }
}

}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -0,0 +1,111 @@
/**
 * \file dnn/test/naive/diag.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"

namespace megdnn {
namespace test {

TEST_F(NAIVE, DiagVector2Matrix) {
    Checker<Diag> checker(handle(), false);
    Diag::Param param;
    param.k = 0;
    checker.set_param(param).exect(
            Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
            Testcase{
                    {},
                    // clang-format off
                    TensorValue({3, 3}, dtype::Float32(), {1, 0, 0,
                                                           0, 2, 0,
                                                           0, 0, 3})});
    // clang-format on
}

TEST_F(NAIVE, DiagVector2Matrix_PositiveK) {
    Checker<Diag> checker(handle(), false);
    Diag::Param param;
    param.k = 1;
    checker.set_param(param).exect(
            Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
            Testcase{
                    {},
                    // clang-format off
                    TensorValue({4, 4}, dtype::Float32(), {0, 1, 0, 0,
                                                           0, 0, 2, 0,
                                                           0, 0, 0, 3,
                                                           0, 0, 0, 0})});
    // clang-format on
}

TEST_F(NAIVE, DiagVector2Matrix_NegativeK) {
    Checker<Diag> checker(handle(), false);
    Diag::Param param;
    param.k = -1;
    checker.set_param(param).exect(
            Testcase{TensorValue({3}, dtype::Float32(), {1, 2, 3}), {}},
            Testcase{
                    {},
                    // clang-format off
                    TensorValue({4, 4}, dtype::Float32(), {0, 0, 0, 0,
                                                           1, 0, 0, 0,
                                                           0, 2, 0, 0,
                                                           0, 0, 3, 0})});
    // clang-format on
}

TEST_F(NAIVE, DiagMatrix2Vector) {
    Checker<Diag> checker(handle(), false);
    Diag::Param param;
    param.k = 0;
    checker.set_param(param).exect(
            // clang-format off
            Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3,
                                                            4, 5, 6,
                                                            7, 8, 9}),
                     // clang-format on
                     {}},
            Testcase{{}, TensorValue({3}, dtype::Float32(), {1, 5, 9})});
}

TEST_F(NAIVE, DiagMatrix2Vector_PositiveK) {
    Checker<Diag> checker(handle(), false);
    Diag::Param param;
    param.k = 1;
    checker.set_param(param).exect(
            // clang-format off
            Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3,
                                                            4, 5, 6,
                                                            7, 8, 9}),
                     // clang-format on
                     {}},
            Testcase{{}, TensorValue({2}, dtype::Float32(), {2, 6})});
}

TEST_F(NAIVE, DiagMatrix2Vector_NegativeK) {
    Checker<Diag> checker(handle(), false);
    Diag::Param param;
    param.k = -1;
    checker.set_param(param).exect(
            // clang-format off
            Testcase{TensorValue({3, 3}, dtype::Float32(), {1, 2, 3,
                                                            4, 5, 6,
                                                            7, 8, 9}),
                     // clang-format on
                     {}},
            Testcase{{}, TensorValue({2}, dtype::Float32(), {4, 8})});
}

}  // namespace test
}  // namespace megdnn
@@ -28,6 +28,7 @@ __all__ = [
    "concat",
    "cond_take",
    "cumsum",
    "diag",
    "expand_dims",
    "eye",
    "flatten",

@@ -53,6 +54,32 @@ __all__ = [
]


def diag(inp, k=0) -> Tensor:
    r"""If ``inp`` is a 1D tensor, returns a 2D tensor with the elements of ``inp`` on the diagonal.
    If ``inp`` is a 2D tensor, returns a 1D tensor with the diagonal elements of ``inp``.

    Args:
        inp: input tensor.
        k: the diagonal to consider. Use :math:`k=0` for the main diagonal, :math:`k>0` for diagonals above the
           main diagonal, and :math:`k<0` for diagonals below the main diagonal. Default: 0.

    Returns:
        the extracted diagonal or the constructed diagonal matrix.

    Examples:
        >>> inp = F.arange(6, dtype='int32').reshape(2,3)
        >>> out = F.diag(inp, k=1)
        >>> out
        Tensor([1 5], dtype=int32, device=xpux:0)
        >>> F.diag(out)
        Tensor([[1 0]
         [0 5]], dtype=int32, device=xpux:0)
    """
    op = builtin.Diag(k=k)
    (result,) = apply(op, inp)
    return result


def eye(N, M=None, *, dtype="float32", device: Optional[CompNode] = None) -> Tensor:
    r"""Returns a 2D tensor with ones on the diagonal and zeros elsewhere.
@@ -42,6 +42,26 @@ def test_eye():
    )


@pytest.mark.parametrize("is_varnode", [False, True])
def test_diag(is_varnode):
    if is_varnode:
        network = Network()
    else:
        network = None

    shapes = [(10, 10), (6, 9), (8, 7), (8,)]
    cases = []
    for shp in shapes:
        cases.append({"input": [np.random.random(shp).astype("float32")]})

    for axis in range(-2, 3):

        def run(data):
            return F.diag(data, k=axis)

        opr_test(cases, run, ref_fn=lambda x: np.diag(x, axis), network=network)


def test_full():
    shape = (2, 3)
    values = [True, 4, 5.0]
@@ -432,6 +432,19 @@ OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace eye
}  // namespace

namespace {
namespace diag {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Diag&>(def);
    mgb_assert(inputs.size() == 1);
    cg::OperatorNodeConfig config{op.make_name()};
    opr::Diag::Param param{op.k};
    return opr::Diag::make(inputs[0], param, config);
}
OP_TRAIT_REG(Diag, Diag).apply_on_var_node(apply_on_var_node).fallback();
}  // namespace diag
}  // namespace

namespace {
namespace roi_pooling {
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {

@@ -240,6 +240,8 @@ def Eye: MgbHashableOp<"Eye", [EyeParam]> {
);
}

def Diag: MgbHashableOp<"Diag", [DiagParam]>;

def GetVarShape : MgbHashableOp<"GetVarShape", [OptionalAxisV1Param]>;

def Concat: MgbHashableOp<"Concat", [AxisParam]> {
@@ -75,6 +75,91 @@ struct MegDNNOprInitInputsModifier<IndexingSetOneHot>
}  // namespace opr
}  // namespace mgb

/* ==================== Diag ==================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(Diag);
MEGDNN_OPR_INIT1(Diag, "diag")

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(Diag) {
    if (wrt_idx == 0) {
        SymbolVar data_sym{opr.input(0)};
        return DiagBackward::make(data_sym.symshape(), out_grad[0], opr.param())
                .node();
    }
    return InvalidGrad::make(opr, wrt_idx);
}
#endif
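
The gradient rule: DiagBackward receives the input's shape (consumed as a host value) together with the output gradient and applies the transposed mapping, scattering a 1-D gradient back onto the k-th diagonal when the forward pass extracted one, or extracting the diagonal when the forward pass built a matrix. A NumPy sketch of that rule (illustrative only, not the graph API):

    import numpy as np

    def diag_grad(input_shape, grad_out, k):
        if len(input_shape) == 2:      # forward extracted a diagonal
            g = np.zeros(input_shape, grad_out.dtype)
            idx = np.arange(grad_out.size)
            rows, cols = (idx, idx + k) if k >= 0 else (idx - k, idx)
            g[rows, cols] = grad_out
            return g
        return np.diag(grad_out, k)    # forward built a matrix: extract back

    g = diag_grad((3, 3), np.array([1.0, 1.0]), k=1)
    assert g[0, 1] == 1.0 and g[1, 2] == 1.0
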
/* ==================== DiagBackward ==================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(DiagBackward);

DiagBackward::DiagBackward(
        VarNode* shape, VarNode* value, const Param& param,
        const OperatorNodeConfig& config)
        : Super{shape->owner_graph(), config, "diag_backward", {shape, value}},
          m_param{param} {
    add_input({shape, value});
    add_output(None)->dtype(value->dtype());
    add_equivalence_component<PODHash<Param>>(&m_param);
}

SymbolVar DiagBackward::make(
        SymbolVar shape, SymbolVar value, const Param& param,
        const OperatorNodeConfig& config) {
    return shape.insert_single_output_opr<DiagBackward>(
            shape.node(), value.node(), param, config);
}

cg::OperatorNodeBase::NodeProp* DiagBackward::do_make_node_prop() const {
    auto prop = Super::do_make_node_prop();
    using D = NodeProp::DepType;
    prop->add_dep_type(input(0), D::HOST_VALUE);
    return prop;
}

void DiagBackward::scn_do_execute() {
    auto&& dest = output(0)->dev_tensor();
    auto&& val = input(1)->dev_tensor();
    auto&& layout = dest.layout();
    mgb_assert(layout.ndim == 1 || layout.ndim == 2);
    if (layout.ndim == 2) {
        dev_tensor_memset(dest, 0);
        size_t offset = (m_param.k >= 0) ? (m_param.k * layout.stride[1])
                                         : (-m_param.k * layout.stride[0]);
        auto dest_sub = dest.sub(SubTensorSpec::make_from_offset_elem(
                {val.shape(), {layout.stride[0] + layout.stride[1]}, val.dtype()},
                offset));
        dest_sub.copy_from_fixlayout(val);
    } else {
        auto&& opr = m_dnn_opr;
        if (!opr) {
            opr = intl::create_megdnn_opr<megdnn::Diag>(comp_node());
            opr->param() = m_param;
        }
        opr->exec(val.as_megdnn(), dest.as_megdnn(), {});
    }
}
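
For the matrix-shaped gradient, scn_do_execute avoids a dedicated kernel: it zero-fills the output, then builds a 1-D sub-tensor view of the k-th diagonal whose single stride is stride[0] + stride[1], and copies the values into that view. A NumPy analogue of the view trick (an assumption for illustration, not the actual graph-runtime API):

    import numpy as np

    def scatter_diag(dst, val, k):
        s0, s1 = (s // dst.itemsize for s in dst.strides)  # element strides
        dst[...] = 0
        offset = k * s1 if k >= 0 else -k * s0
        diag_view = np.lib.stride_tricks.as_strided(
                dst.reshape(-1)[offset:], shape=val.shape,
                strides=((s0 + s1) * dst.itemsize,))
        diag_view[...] = val           # one strided copy fills the diagonal

    out = np.empty((4, 4), np.float32)
    scatter_diag(out, np.array([1, 2, 3], np.float32), k=1)
    assert out[0, 1] == 1 and out[2, 3] == 3
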
void DiagBackward::record_execute_deps(ExecDependencyArray& deps) {
    deps.emplace_back(std::make_unique<intl::MegDNNGraphDep>(std::move(m_dnn_opr)));
}

void DiagBackward::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    auto infer_shape = [](TensorShape& dest, const InpVal& inp) {
        cg::copy_tensor_value_to_shape(dest, inp.val.at(0).value());
        return true;
    };
    mgr.register_shape_infer(
            output(0), {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_shape});
}

#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(DiagBackward) {
    return InvalidGrad::make(opr, wrt_idx);
}
#endif

/* ==================== IndexingOneHot ==================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(IndexingOneHot);
MEGDNN_OPR_INIT2(IndexingOneHot, "indexing_one_hot")
@@ -1,3 +1,25 @@
decl_opr(
    'Diag',
    desc='Extract a diagonal or construct a diagonal array',
    inputs=[
        Doc('k', 'Index of the diagonal: 0 (the default) refers to the main '
            'diagonal, a positive value refers to an upper diagonal, and a '
            'negative value to a lower diagonal.')
    ],
    params='Diag'
)

decl_opr(
    'DiagBackward',
    desc='backward function of Diag',
    inputs=[
        Doc('k', 'Index of the diagonal: 0 (the default) refers to the main '
            'diagonal, a positive value refers to an upper diagonal, and a '
            'negative value to a lower diagonal.')
    ],
    params='Diag'
)

decl_opr('IndexingOneHot', pyname='_indexing_one_hot',
         inputs=['src', 'index'],
         params=[('axis', 'Axis')])

@@ -25,6 +25,8 @@ MGB_SEREG_MODIFY_SUBTENSOR_OPR(BatchedSetMeshIndexing);
namespace mgb {
namespace opr {

MGB_SEREG_OPR(Diag, 1);
MGB_SEREG_OPR(DiagBackward, 2);

MGB_SEREG_OPR(IndexingOneHot, 2);
MGB_SEREG_OPR(IndexingRemap, 2);
MGB_SEREG_OPR(IndexingRemapBackward, 3);
@@ -19,6 +19,37 @@
namespace mgb {
namespace opr {

MGB_DEFINE_OPR_CLASS(Diag, intl::MegDNNOprWrapperFwd<megdnn::Diag>) // {
public:
    MGE_WIN_DECLSPEC_FUC Diag(
            VarNode* src, const Param& param, const OperatorNodeConfig& config);
    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar src, const Param& param, const OperatorNodeConfig& config = {});
};

MGB_DEFINE_OPR_CLASS(DiagBackward, cg::SingleCNOperatorNodeBase) // {
public:
    using Param = megdnn::Diag::Param;
    MGE_WIN_DECLSPEC_FUC DiagBackward(
            VarNode* shape, VarNode* value, const Param& param,
            const OperatorNodeConfig& config);
    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar shape, SymbolVar value, const Param& param,
            const OperatorNodeConfig& config = {});
    const Param& param() const { return m_param; }

private:
    Param m_param;
    intl::UniqPtrWithCN<megdnn::Diag> m_dnn_opr;

    void scn_do_execute() override;
    void init_output_static_infer_desc() override;
    NodeProp* do_make_node_prop() const override;
    void record_execute_deps(ExecDependencyArray& deps) override;
};

MGB_DEFINE_OPR_CLASS(
        IndexingOneHot, intl::MegDNNOprWrapperFwd<megdnn::IndexingOneHotForward>) // {
public:
@@ -52,6 +52,37 @@ void gen_index_onehot(int* max_value, HostTensorND& dest) {
    }
}

void test_diag(int32_t axis, const TensorShapeArray& test_cases) {
    using Checker = AutoOprChecker<1, 1>;
    auto nopr = megdnn_naive_handle()->create_operator<megdnn::Diag>();
    nopr->param() = {axis};

    auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
        return {opr::Diag::make(inputs[0], {axis})};
    };

    auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
        auto&& src = *inp[0];
        TensorShape oshp(src.shape());
        if (oshp.ndim == 1) {
            size_t o = oshp.shape[0] + std::abs(axis);
            oshp = {o, o};
        } else {
            size_t m = oshp.shape[0];
            size_t n = oshp.shape[1];
            size_t o = (axis >= 0) ? std::min(n - axis, m) : std::min(m + axis, n);
            oshp = {o};
        }
        dest[0].resize(oshp);
        nopr->exec(src.as_megdnn(), dest[0].as_megdnn(), {});
    };

    Checker checker{make_graph, fwd};
    for (auto&& i : test_cases) {
        checker.run({i});
    }
}

void test_one_hot_get(int32_t axis, const TensorShapeArray& test_cases) {
    using Checker = AutoOprChecker<2, 1>;

@@ -145,6 +176,12 @@ void test_one_hot(int32_t axis, const TensorShapeArray& test_cases) {
}  // anonymous namespace

TEST(TestOprDiag, Diag) {
    TensorShapeArray cases = {{7, 7}, {7, 9}, {9, 7}, {8}};
    for (int32_t k = -3; k < 3; ++k)
        test_diag(k, cases);
}

TEST(TestOprIndexing, OneHot2D) {
    TensorShapeArray cases = {{1, 1}, {2, 2}, {10, 8}, {8, 10}};
    test_one_hot(0, cases);
@@ -122,6 +122,7 @@ union OperatorParam {
    param.RNN = 88,
    param.LSTM = 89,
    param.Softmax = 90,
    param.Diag = 91,
}

table Operator {