| @@ -1741,6 +1741,67 @@ protected: | |||
| const TensorLayout& grad_s, size_t workspace_in_bytes); | |||
| }; | |||
| class LSQBase : public OperatorBase { | |||
| DEF_OPR_IMPL_CTOR(LSQBase, OperatorBase); | |||
| DEF_OPR_PARAM(LSQ); | |||
| protected: | |||
| void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output); | |||
| void check_layout_fwd(const TensorLayout& input, const TensorLayout& scale, | |||
| const TensorLayout& zero_point, | |||
| const TensorLayout& grad_scale, | |||
| const TensorLayout& output); | |||
| }; | |||
| class LSQForward : public LSQBase { | |||
| DEF_OPR_IMPL(LSQForward, LSQBase, 4, 1); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, | |||
| _megdnn_tensor_in zero_point, | |||
| _megdnn_tensor_in grad_scale, _megdnn_tensor_out output, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout(const TensorLayout& input, const TensorLayout& scale, | |||
| const TensorLayout& zero_point, | |||
| const TensorLayout& grad_scale, TensorLayout& output); | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& input, | |||
| const TensorLayout& scale, | |||
| const TensorLayout& zero_point, | |||
| const TensorLayout& grad_scale, | |||
| const TensorLayout& output) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& input, const TensorLayout& scale, | |||
| const TensorLayout& zero_point, | |||
| const TensorLayout& grad_scale, const TensorLayout& output, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| using LSQ = LSQForward; | |||
| class LSQBackward : public LSQBase { | |||
| DEF_OPR_IMPL(LSQBackward, LSQBase, 5, 2); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, | |||
| _megdnn_tensor_in scale, _megdnn_tensor_in zero_point, | |||
| _megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x, | |||
| _megdnn_tensor_out grad_s, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||
| const TensorLayout& input, | |||
| const TensorLayout& scale, | |||
| const TensorLayout& zero_point, | |||
| const TensorLayout& grad_scale, | |||
| const TensorLayout& grad_x, | |||
| const TensorLayout& grad_s) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& diff, const TensorLayout& input, | |||
| const TensorLayout& scale, const TensorLayout& zero_point, | |||
| const TensorLayout& grad_scale, const TensorLayout& grad_x, | |||
| const TensorLayout& grad_s, size_t workspace_in_bytes); | |||
| }; | |||
| } // namespace megdnn | |||
| #include "megdnn/internal/opr_header_epilogue.h" | |||
| @@ -1124,3 +1124,8 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o | |||
| add_fields('int32', 'qmin', '-2147483648'). | |||
| add_fields('int32', 'qmax', '2147483647') | |||
| ) | |||
| (pdef('LSQ'). | |||
| add_fields('int32', 'qmin', '-2147483648'). | |||
| add_fields('int32', 'qmax', '2147483647') | |||
| ) | |||
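The qmin/qmax defaults above span the whole int32 range, so a default-constructed param effectively leaves the input unclamped. A quick sanity check of the bounds:

import numpy as np
info = np.iinfo(np.int32)
assert (info.min, info.max) == (-2147483648, 2147483647)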
| @@ -37,6 +37,7 @@ namespace megdnn { | |||
| megdnn_assert(size, "uninitialized ElemwiseOpParamN"); | |||
| } | |||
| template struct ElemwiseOpParamN<7>; | |||
| template struct ElemwiseOpParamN<6>; | |||
| template struct ElemwiseOpParamN<5>; | |||
| template struct ElemwiseOpParamN<4>; | |||
| @@ -208,7 +208,9 @@ private: | |||
| cb(FakeQuantBackward) \ | |||
| cb(TQTForward) \ | |||
| cb(TQTBackward) \ | |||
| cb(CheckHasInf) | |||
| cb(CheckHasInf) \ | |||
| cb(LSQForward) \ | |||
| cb(LSQBackward) | |||
| /*! | |||
| * \brief specialize HandleImpl::create_operator for a single opr type; | |||
| @@ -0,0 +1,69 @@ | |||
| /** | |||
| * \file dnn/src/common/lsq.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megdnn/oprs.h" | |||
| #include "src/common/utils.h" | |||
| namespace megdnn { | |||
| void LSQBase::deduce_layout_fwd(const TensorLayout& input, | |||
| TensorLayout& output) { | |||
| output = TensorLayout(input, input.dtype); | |||
| } | |||
| void LSQBase::check_layout_fwd(const TensorLayout& input, | |||
| const TensorLayout& scale, | |||
| const TensorLayout& zero_point, | |||
| const TensorLayout& grad_scale, | |||
| const TensorLayout& output) { | |||
| megdnn_assert(input.dtype == dtype::Float32()); | |||
| megdnn_assert(scale.dtype == dtype::Float32()); | |||
| megdnn_assert(zero_point.dtype == dtype::Float32()); | |||
| megdnn_assert(grad_scale.dtype == dtype::Float32()); | |||
| TensorLayout expected; | |||
| deduce_layout_fwd(input, expected); | |||
| megdnn_assert_eq_layout(expected, output); | |||
| } | |||
| void LSQForward::deduce_layout(const TensorLayout& input, | |||
| const TensorLayout& /* scale */, | |||
| const TensorLayout& /*zero_point*/, | |||
| const TensorLayout& /*grad_scale*/, | |||
| TensorLayout& output) { | |||
| deduce_layout_fwd(input, output); | |||
| } | |||
| void LSQForward::check_exec(const TensorLayout& input, | |||
| const TensorLayout& scale, | |||
| const TensorLayout& zero_point, | |||
| const TensorLayout& grad_scale, | |||
| const TensorLayout& output, | |||
| size_t workspace_in_bytes) { | |||
| check_layout_fwd(input, scale, zero_point, grad_scale, output); | |||
| auto required_workspace_space = get_workspace_in_bytes( | |||
| input, scale, zero_point, grad_scale, output); | |||
| megdnn_assert(workspace_in_bytes >= required_workspace_space); | |||
| } | |||
| void LSQBackward::check_exec( | |||
| const TensorLayout& diff, const TensorLayout& input, | |||
| const TensorLayout& scale, const TensorLayout& zero_point, | |||
| const TensorLayout& grad_scale, const TensorLayout& grad_x, | |||
| const TensorLayout& grad_s, size_t workspace_in_bytes) { | |||
| megdnn_assert_eq_shape(diff, input); | |||
| megdnn_assert_eq_shape(grad_x, input); | |||
| auto required_workspace_space = get_workspace_in_bytes( | |||
| diff, input, scale, zero_point, grad_scale, grad_x, grad_s); | |||
| megdnn_assert(workspace_in_bytes >= required_workspace_space); | |||
| } | |||
| } // namespace megdnn | |||
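The contract enforced above: all operands must be Float32, and the forward output layout is deduced to equal the input layout. A hypothetical Python mirror of the same checks (illustration only, assuming NumPy arrays):

import numpy as np

def check_layout_fwd_ref(inp, scale, zero_point, grad_scale, out):
    for t in (inp, scale, zero_point, grad_scale):
        assert t.dtype == np.float32
    assert out.shape == inp.shape and out.dtype == inp.dtype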
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| @@ -121,6 +122,8 @@ DEF(UniformRNG, 1, true, true); | |||
| DEF(GaussianRNG, 1, true, true); | |||
| DEF(ChecksumForward, 1, true, false); | |||
| DEF(CheckHasInf, 2, true, true); | |||
| DEF(LSQForward, 5, true, true); | |||
| DEF(LSQBackward, 7, true, false); | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -947,6 +947,119 @@ struct OpCallerUniform<Op, 5, PVis> { | |||
| } | |||
| }; | |||
| //! specialization for arity == 6 | |||
| template <class Op, class PVis> | |||
| struct OpCallerUniform<Op, 6, PVis> { | |||
| Op op; | |||
| PVis par[6]; | |||
| static const uint32_t packed_size = PVis::packed_size; | |||
| devfunc void thread_init(uint32_t idx) { | |||
| idx = idx * packed_size; | |||
| par[0].thread_init(idx); | |||
| par[1].thread_init(idx); | |||
| par[2].thread_init(idx); | |||
| par[3].thread_init(idx); | |||
| par[4].thread_init(idx); | |||
| par[5].thread_init(idx); | |||
| } | |||
| devfunc void on(uint32_t idx) { | |||
| idx = idx * packed_size; | |||
| op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), par[3].at(idx), | |||
| par[4].at(idx), par[5].at(idx)); | |||
| } | |||
| devfunc void on(uint32_t idx, uint32_t remain) { | |||
| idx = idx * packed_size; | |||
| if (remain >= packed_size) { | |||
| op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), | |||
| par[3].at(idx), par[4].at(idx), par[5].at(idx)); | |||
| } else { | |||
| auto ptr0 = par[0].ptr(); | |||
| auto ptr1 = par[1].ptr(); | |||
| auto ptr2 = par[2].ptr(); | |||
| auto ptr3 = par[3].ptr(); | |||
| auto ptr4 = par[4].ptr(); | |||
| auto ptr5 = par[5].ptr(); | |||
| for (int i = 0; i < remain; i++) { | |||
| op(idx + i, ptr0[par[0].offset(idx + i)], | |||
| ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)], | |||
| ptr3[par[3].offset(idx + i)], ptr4[par[4].offset(idx + i)], | |||
| ptr5[par[5].offset(idx + i)]); | |||
| } | |||
| } | |||
| } | |||
| devfunc void next() { | |||
| par[0].next(); | |||
| par[1].next(); | |||
| par[2].next(); | |||
| par[3].next(); | |||
| par[4].next(); | |||
| par[5].next(); | |||
| } | |||
| }; | |||
| //! specialization for arity == 7 | |||
| template <class Op, class PVis> | |||
| struct OpCallerUniform<Op, 7, PVis> { | |||
| Op op; | |||
| PVis par[7]; | |||
| static const uint32_t packed_size = PVis::packed_size; | |||
| devfunc void thread_init(uint32_t idx) { | |||
| idx = idx * packed_size; | |||
| par[0].thread_init(idx); | |||
| par[1].thread_init(idx); | |||
| par[2].thread_init(idx); | |||
| par[3].thread_init(idx); | |||
| par[4].thread_init(idx); | |||
| par[5].thread_init(idx); | |||
| par[6].thread_init(idx); | |||
| } | |||
| devfunc void on(uint32_t idx) { | |||
| idx = idx * packed_size; | |||
| op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), par[3].at(idx), | |||
| par[4].at(idx), par[5].at(idx), par[6].at(idx)); | |||
| } | |||
| devfunc void on(uint32_t idx, uint32_t remain) { | |||
| idx = idx * packed_size; | |||
| if (remain >= packed_size) { | |||
| op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx), | |||
| par[3].at(idx), par[4].at(idx), par[5].at(idx), par[6].at(idx)); | |||
| } else { | |||
| auto ptr0 = par[0].ptr(); | |||
| auto ptr1 = par[1].ptr(); | |||
| auto ptr2 = par[2].ptr(); | |||
| auto ptr3 = par[3].ptr(); | |||
| auto ptr4 = par[4].ptr(); | |||
| auto ptr5 = par[5].ptr(); | |||
| auto ptr6 = par[6].ptr(); | |||
| for (int i = 0; i < remain; i++) { | |||
| op(idx + i, ptr0[par[0].offset(idx + i)], | |||
| ptr1[par[1].offset(idx + i)], ptr2[par[2].offset(idx + i)], | |||
| ptr3[par[3].offset(idx + i)], ptr4[par[4].offset(idx + i)], | |||
| ptr5[par[5].offset(idx + i)], ptr6[par[6].offset(idx + i)]); | |||
| } | |||
| } | |||
| } | |||
| devfunc void next() { | |||
| par[0].next(); | |||
| par[1].next(); | |||
| par[2].next(); | |||
| par[3].next(); | |||
| par[4].next(); | |||
| par[5].next(); | |||
| par[6].next(); | |||
| } | |||
| }; | |||
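Both new specializations follow the pattern of the existing arity-2..5 callers: on(idx) handles one full packed vector, and on(idx, remain) falls back to an element-by-element loop for the tail when fewer than packed_size elements remain. A Python sketch of that dispatch, assuming a scalar op for simplicity:

def run(total, packed_size, on_packed, on_scalar):
    idx = 0
    while idx < total:
        remain = total - idx
        if remain >= packed_size:
            on_packed(idx)           # vectorized fast path
        else:
            for i in range(remain):  # scalar tail
                on_scalar(idx + i)
        idx += packed_size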
| /*! | |||
| * \brief call binary (i.e. arity == 2) operator with different param | |||
| * visitors | |||
| @@ -6,7 +6,8 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/common/handle_impl.h" | |||
| @@ -15,6 +16,7 @@ | |||
| #include "src/cuda/add_update/opr_impl.h" | |||
| #include "src/cuda/argmxx/opr_impl.h" | |||
| #include "src/cuda/argsort/opr_impl.h" | |||
| #include "src/cuda/batch_conv_bias/opr_impl.h" | |||
| #include "src/cuda/batch_normalization/opr_impl.h" | |||
| #include "src/cuda/batched_matrix_mul/opr_impl.h" | |||
| #include "src/cuda/check_has_inf/opr_impl.h" | |||
| @@ -35,6 +37,7 @@ | |||
| #include "src/cuda/elemwise/opr_impl.h" | |||
| #include "src/cuda/elemwise_multi_type/opr_impl.h" | |||
| #include "src/cuda/eye/opr_impl.h" | |||
| #include "src/cuda/fake_quant/opr_impl.h" | |||
| #include "src/cuda/flip/opr_impl.h" | |||
| #include "src/cuda/gaussian_blur/opr_impl.h" | |||
| #include "src/cuda/group_local/opr_impl.h" | |||
| @@ -45,6 +48,7 @@ | |||
| #include "src/cuda/local/opr_impl.h" | |||
| #include "src/cuda/local_share/opr_impl.h" | |||
| #include "src/cuda/lrn/opr_impl.h" | |||
| #include "src/cuda/lsq/opr_impl.h" | |||
| #include "src/cuda/mask_conv/opr_impl.h" | |||
| #include "src/cuda/matrix_inverse/opr_impl.h" | |||
| #include "src/cuda/matrix_mul/opr_impl.h" | |||
| @@ -56,9 +60,11 @@ | |||
| #include "src/cuda/reduce/opr_impl.h" | |||
| #include "src/cuda/relayout/opr_impl.h" | |||
| #include "src/cuda/relayout_format/opr_impl.h" | |||
| #include "src/cuda/remap/opr_impl.h" | |||
| #include "src/cuda/repeat/opr_impl.h" | |||
| #include "src/cuda/resize/opr_impl.h" | |||
| #include "src/cuda/rng/opr_impl.h" | |||
| #include "src/cuda/roi_align/opr_impl.h" | |||
| #include "src/cuda/roi_copy/opr_impl.h" | |||
| #include "src/cuda/roi_pooling/opr_impl.h" | |||
| #include "src/cuda/rotate/opr_impl.h" | |||
| @@ -70,16 +76,11 @@ | |||
| #include "src/cuda/tensor_remap/opr_impl.h" | |||
| #include "src/cuda/tile/opr_impl.h" | |||
| #include "src/cuda/topk/opr_impl.h" | |||
| #include "src/cuda/tqt/opr_impl.h" | |||
| #include "src/cuda/transpose/opr_impl.h" | |||
| #include "src/cuda/type_cvt/opr_impl.h" | |||
| #include "src/cuda/warp_affine/opr_impl.h" | |||
| #include "src/cuda/warp_perspective/opr_impl.h" | |||
| #include "src/cuda/local_share/opr_impl.h" | |||
| #include "src/cuda/roi_align/opr_impl.h" | |||
| #include "src/cuda/batch_conv_bias/opr_impl.h" | |||
| #include "src/cuda/remap/opr_impl.h" | |||
| #include "src/cuda/fake_quant/opr_impl.h" | |||
| #include "src/cuda/tqt/opr_impl.h" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| @@ -0,0 +1,30 @@ | |||
| /** | |||
| * \file dnn/src/cuda/lsq/kern.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "./kern.cuh" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| #define cb(_dtype) \ | |||
| INST_RUN_ELEMWISE(LSQKernOp<DTypeTrait<_dtype>::ctype>, \ | |||
| DTypeTrait<_dtype>::ctype, 3); \ | |||
| INST_RUN_ELEMWISE(LSQBwdKernOp<DTypeTrait<_dtype>::ctype>, \ | |||
| DTypeTrait<_dtype>::ctype, 3); \ | |||
| INST_RUN_ELEMWISE(LSQKernOpNonContig<DTypeTrait<_dtype>::ctype>, \ | |||
| DTypeTrait<_dtype>::ctype, 5); \ | |||
| INST_RUN_ELEMWISE(LSQBwdKernOpNonContig<DTypeTrait<_dtype>::ctype>, \ | |||
| DTypeTrait<_dtype>::ctype, 7); | |||
| cb(megdnn::dtype::Float32) | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| @@ -0,0 +1,126 @@ | |||
| /** | |||
| * \file dnn/src/cuda/lsq/kern.cuh | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/cuda/elemwise_helper.cuh" | |||
| #include "src/cuda/utils.cuh" | |||
| #if MEGDNN_CC_HOST | |||
| #include "megdnn/oprs.h" | |||
| #endif | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| template <typename ctype> | |||
| struct LSQKernOp { | |||
| ctype* input; | |||
| ctype* output; | |||
| ctype qmin, qmax; | |||
| __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point, | |||
| ctype grad_scale) { | |||
| ctype x = input[idx] / scale + zero_point; | |||
| x = fmaxf(fminf(x, qmax), qmin); | |||
| x = round(x); | |||
| output[idx] = (x - zero_point) * scale; | |||
| } | |||
| #if MEGDNN_CC_HOST | |||
| LSQKernOp(const TensorND& input, const TensorND& output, | |||
| const LSQ::Param& param) | |||
| : input{input.ptr<ctype>()}, | |||
| output{output.ptr<ctype>()}, | |||
| qmin(param.qmin), | |||
| qmax(param.qmax) {} | |||
| #endif | |||
| }; | |||
| template <typename ctype> | |||
| struct LSQBwdKernOp { | |||
| ctype* diff; | |||
| ctype* input; | |||
| ctype* grad_x; | |||
| ctype* grad_s; | |||
| ctype qmin, qmax; | |||
| __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point, | |||
| ctype grad_scale) { | |||
| ctype x = input[idx] / scale + zero_point; | |||
| // indicator masks: below the range, above it, or inside it | |||
| bool ind_small = x < qmin; | |||
| bool ind_big = x > qmax; | |||
| bool ind_middle = ind_small ^ ind_big; | |||
| ind_middle = !ind_middle; | |||
| // step-size gradient: qmin/qmax outside the range, round(x) - x inside | |||
| grad_s[idx] = ind_small * qmin + ind_big * qmax + | |||
| ind_middle * (-x + round(x)); | |||
| grad_s[idx] = grad_s[idx] * grad_scale * diff[idx]; | |||
| // input gradient passes through only inside the range | |||
| grad_x[idx] = ind_middle * diff[idx]; | |||
| } | |||
| #if MEGDNN_CC_HOST | |||
| LSQBwdKernOp(const TensorND& diff, const TensorND& input, | |||
| const TensorND& grad_x, const TensorND& grad_s, | |||
| const LSQ::Param& param) | |||
| : diff{diff.ptr<ctype>()}, | |||
| input{input.ptr<ctype>()}, | |||
| grad_x{grad_x.ptr<ctype>()}, | |||
| grad_s{grad_s.ptr<ctype>()}, | |||
| qmin(param.qmin), | |||
| qmax(param.qmax) {} | |||
| #endif | |||
| }; | |||
| template <typename ctype> | |||
| struct LSQKernOpNonContig { | |||
| ctype qmin; | |||
| ctype qmax; | |||
| __device__ void operator()(uint32_t, ctype& output, ctype& input, | |||
| ctype& scale, ctype& zero_point, | |||
| ctype grad_scale) { | |||
| ctype x = input / scale + zero_point; | |||
| x = fmaxf(fminf(x, qmax), qmin); | |||
| x = round(x); | |||
| output = (x - zero_point) * scale; | |||
| } | |||
| #if MEGDNN_CC_HOST | |||
| LSQKernOpNonContig(const LSQ::Param& param) | |||
| : qmin(param.qmin), qmax(param.qmax) {} | |||
| #endif | |||
| }; | |||
| template <typename ctype> | |||
| struct LSQBwdKernOpNonContig { | |||
| ctype qmin; | |||
| ctype qmax; | |||
| __device__ void operator()(uint32_t, ctype& grad_x, ctype& grad_s, | |||
| ctype& diff, ctype& input, ctype& scale, | |||
| ctype& zero_point, ctype grad_scale) { | |||
| ctype x = input / scale + zero_point; | |||
| bool ind_small = x < qmin; | |||
| bool ind_big = x > qmax; | |||
| bool ind_middle = ind_small ^ ind_big; | |||
| ind_middle = !ind_middle; | |||
| grad_s = ind_small * qmin + ind_big * qmax + | |||
| ind_middle * (-x + round(x)); | |||
| grad_s = grad_s * grad_scale * diff; | |||
| grad_x = ind_middle * diff; | |||
| } | |||
| #if MEGDNN_CC_HOST | |||
| LSQBwdKernOpNonContig(const LSQ::Param& param) | |||
| : qmin(param.qmin), qmax(param.qmax) {} | |||
| #endif | |||
| }; | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
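The backward kernels compute the LSQ gradients: grad_x passes diff through only where the scaled input lies inside [qmin, qmax], and grad_s follows the step-size gradient from the LSQ paper. A NumPy sketch of the same math (hypothetical helper, not part of this patch):

import numpy as np

def lsq_backward_ref(diff, x, scale, zero_point, grad_scale, qmin, qmax):
    xs = x / scale + zero_point
    small, big = xs < qmin, xs > qmax
    middle = ~(small | big)              # inside the quantization range
    grad_s = np.where(small, qmin, np.where(big, qmax, np.round(xs) - xs))
    grad_s = grad_s * grad_scale * diff  # per-element; the reduction happens downstream
    grad_x = np.where(middle, diff, 0.0)
    return grad_x, grad_s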
| @@ -0,0 +1,151 @@ | |||
| /** | |||
| * \file dnn/src/cuda/lsq/opr_impl.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "./opr_impl.h" | |||
| #include "./kern.cuh" | |||
| #include "src/common/utils.h" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| void LSQForwardImpl::exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, | |||
| _megdnn_tensor_in zero_point, | |||
| _megdnn_tensor_in grad_scale, | |||
| _megdnn_tensor_out output, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(input.layout, scale.layout, zero_point.layout, grad_scale.layout, | |||
| output.layout, workspace.size); | |||
| if (!input.layout.is_contiguous() || !output.layout.is_contiguous()) | |||
| return exec_noncontig(input, scale, zero_point, grad_scale, output); | |||
| ElemwiseOpParamN<3> ele_param; | |||
| ele_param[0] = scale; | |||
| ele_param[0].layout = ele_param[0].layout.broadcast(input.layout); | |||
| ele_param[1] = zero_point; | |||
| ele_param[1].layout = ele_param[1].layout.broadcast(input.layout); | |||
| ele_param[2] = grad_scale; | |||
| ele_param[2].layout = ele_param[2].layout.broadcast(input.layout); | |||
| ele_param.init_from_given_tensor(); | |||
| auto m_param = param(); | |||
| auto stream = cuda_stream(handle()); | |||
| #define cb(DType) \ | |||
| if (input.layout.dtype == DType()) { \ | |||
| using T = typename DTypeTrait<DType>::ctype; \ | |||
| run_elemwise<LSQKernOp<T>, T, 3>(ele_param, stream, \ | |||
| {input, output, m_param}); \ | |||
| return; \ | |||
| } | |||
| cb(megdnn::dtype::Float32) | |||
| #undef cb | |||
| } | |||
| void LSQForwardImpl::exec_noncontig(_megdnn_tensor_in input, | |||
| _megdnn_tensor_in scale, | |||
| _megdnn_tensor_in zero_point, | |||
| _megdnn_tensor_in grad_scale, | |||
| _megdnn_tensor_out output) { | |||
| ElemwiseOpParamN<5> ele_param; | |||
| ele_param[0] = output; | |||
| ele_param[1] = input; | |||
| ele_param[2] = scale; | |||
| ele_param[2].layout = ele_param[2].layout.broadcast(input.layout); | |||
| ele_param[3] = zero_point; | |||
| ele_param[3].layout = ele_param[3].layout.broadcast(input.layout); | |||
| ele_param[4] = grad_scale; | |||
| ele_param[4].layout = ele_param[4].layout.broadcast(input.layout); | |||
| ele_param.init_from_given_tensor(); | |||
| auto m_param = param(); | |||
| auto stream = cuda_stream(handle()); | |||
| #define cb(DType) \ | |||
| if (input.layout.dtype == DType()) { \ | |||
| using T = typename DTypeTrait<DType>::ctype; \ | |||
| run_elemwise<LSQKernOpNonContig<T>, T, 5>(ele_param, stream, \ | |||
| {m_param}); \ | |||
| return; \ | |||
| } | |||
| cb(megdnn::dtype::Float32) | |||
| #undef cb | |||
| } | |||
| void LSQBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, | |||
| _megdnn_tensor_in scale, | |||
| _megdnn_tensor_in zero_point, | |||
| _megdnn_tensor_in grad_scale, | |||
| _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(diff.layout, input.layout, scale.layout, zero_point.layout, | |||
| grad_scale.layout, grad_x.layout, grad_s.layout, workspace.size); | |||
| if (!input.layout.is_contiguous() || !diff.layout.is_contiguous() || | |||
| !grad_x.layout.is_contiguous() || !grad_s.layout.is_contiguous()) | |||
| return exec_noncontig(diff, input, scale, zero_point, grad_scale, | |||
| grad_x, grad_s); | |||
| ElemwiseOpParamN<3> ele_param; | |||
| ele_param[0] = scale; | |||
| ele_param[0].layout = ele_param[0].layout.broadcast(input.layout); | |||
| ele_param[1] = zero_point; | |||
| ele_param[1].layout = ele_param[1].layout.broadcast(input.layout); | |||
| ele_param[2] = grad_scale; | |||
| ele_param[2].layout = ele_param[2].layout.broadcast(input.layout); | |||
| ele_param.init_from_given_tensor(); | |||
| auto m_param = param(); | |||
| auto stream = cuda_stream(handle()); | |||
| #define cb(DType) \ | |||
| if (grad_x.layout.dtype == DType()) { \ | |||
| using T = typename DTypeTrait<DType>::ctype; \ | |||
| run_elemwise<LSQBwdKernOp<T>, T, 3>( \ | |||
| ele_param, stream, {diff, input, grad_x, grad_s, m_param}); \ | |||
| return; \ | |||
| } | |||
| cb(megdnn::dtype::Float32) | |||
| #undef cb | |||
| } | |||
| void LSQBackwardImpl::exec_noncontig(_megdnn_tensor_in diff, | |||
| _megdnn_tensor_in input, | |||
| _megdnn_tensor_in scale, | |||
| _megdnn_tensor_in zero_point, | |||
| _megdnn_tensor_in grad_scale, | |||
| _megdnn_tensor_out grad_x, | |||
| _megdnn_tensor_out grad_s) { | |||
| ElemwiseOpParamN<7> ele_param; | |||
| ele_param[0] = grad_x; | |||
| ele_param[1] = grad_s; | |||
| ele_param[2] = diff; | |||
| ele_param[3] = input; | |||
| ele_param[4] = scale; | |||
| ele_param[4].layout = ele_param[4].layout.broadcast(input.layout); | |||
| ele_param[5] = zero_point; | |||
| ele_param[5].layout = ele_param[5].layout.broadcast(input.layout); | |||
| ele_param[6] = grad_scale; | |||
| ele_param[6].layout = ele_param[6].layout.broadcast(input.layout); | |||
| ele_param.init_from_given_tensor(); | |||
| auto m_param = param(); | |||
| auto stream = cuda_stream(handle()); | |||
| #define cb(DType) \ | |||
| if (input.layout.dtype == DType()) { \ | |||
| using T = typename DTypeTrait<DType>::ctype; \ | |||
| run_elemwise<LSQBwdKernOpNonContig<T>, T, 7>(ele_param, stream, \ | |||
| {m_param}); \ | |||
| return; \ | |||
| } | |||
| cb(megdnn::dtype::Float32) | |||
| #undef cb | |||
| } | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
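Note that scale, zero_point and grad_scale arrive as (typically single-element) tensors and are broadcast to the input layout before the kernel launch, so the elemwise machinery sees one value per input element. NumPy broadcasting behaves the same way:

import numpy as np
x = np.zeros((2, 3, 4), dtype=np.float32)
scale = np.float32(0.5)  # scalar parameter
assert np.broadcast_to(scale, x.shape).shape == x.shape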
| @@ -0,0 +1,65 @@ | |||
| /** | |||
| * \file dnn/src/cuda/lsq/opr_impl.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| #include "src/cuda/utils.h" | |||
| namespace megdnn { | |||
| namespace cuda { | |||
| class LSQForwardImpl final : public LSQForward { | |||
| public: | |||
| using LSQForward::LSQForward; | |||
| void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, | |||
| _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale, | |||
| _megdnn_tensor_out output, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout&, /* input */ | |||
| const TensorLayout&, /* scale */ | |||
| const TensorLayout&, /* zero_point */ | |||
| const TensorLayout&, /* grad_scale */ | |||
| const TensorLayout& /* output */) override { | |||
| return 0; | |||
| } | |||
| private: | |||
| void exec_noncontig(_megdnn_tensor_in input, _megdnn_tensor_in scale, | |||
| _megdnn_tensor_in zero_point, | |||
| _megdnn_tensor_in grad_scale, | |||
| _megdnn_tensor_out output); | |||
| }; | |||
| class LSQBackwardImpl final : public LSQBackward { | |||
| public: | |||
| using LSQBackward::LSQBackward; | |||
| void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, | |||
| _megdnn_tensor_in scale, _megdnn_tensor_in zero_point, | |||
| _megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x, | |||
| _megdnn_tensor_out grad_s, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout& /* diff */, | |||
| const TensorLayout& /* input */, | |||
| const TensorLayout& /* scale */, | |||
| const TensorLayout& /* zero_point */, | |||
| const TensorLayout& /* grad_scale */, | |||
| const TensorLayout& /* grad_x */, | |||
| const TensorLayout& /* grad_s */) override { | |||
| return 0; | |||
| } | |||
| private: | |||
| void exec_noncontig(_megdnn_tensor_in diff, _megdnn_tensor_in input, | |||
| _megdnn_tensor_in scale, _megdnn_tensor_in zero_point, | |||
| _megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x, | |||
| _megdnn_tensor_out grad_s); | |||
| }; | |||
| } // namespace cuda | |||
| } // namespace megdnn | |||
| @@ -50,6 +50,7 @@ | |||
| #include "src/naive/local/opr_impl.h" | |||
| #include "src/naive/local_share/opr_impl.h" | |||
| #include "src/naive/lrn/opr_impl.h" | |||
| #include "src/naive/lsq/opr_impl.h" | |||
| #include "src/naive/mask_conv/opr_impl.h" | |||
| #include "src/naive/matrix_inverse/opr_impl.h" | |||
| #include "src/naive/matrix_mul/opr_impl.h" | |||
| @@ -0,0 +1,141 @@ | |||
| /** | |||
| * \file dnn/src/naive/lsq/opr_impl.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "src/naive/lsq/opr_impl.h" | |||
| #include <cmath> | |||
| #include "megdnn/tensor_iter.h" | |||
| #include "src/common/elemwise_helper.cuh" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| namespace { | |||
| using namespace megdnn; | |||
| template <typename T> | |||
| void forward_impl(const ElemwiseOpParamN<5> src, float qmin, float qmax) { | |||
| auto inp = tensor_iter_valonly<T>(src[0]).begin(); | |||
| auto out = tensor_iter_valonly<T>(src[1]).begin(); | |||
| auto scale = tensor_iter_valonly<T>(src[2]).begin(); | |||
| auto zero_point = tensor_iter_valonly<T>(src[3]).begin(); | |||
| auto grad_scale = tensor_iter_valonly<T>(src[4]).begin(); | |||
| size_t total = src[0].layout.total_nr_elems(); | |||
| for (size_t i = 0; i < total; ++i) { | |||
| T x = (*inp) / (*scale) + (*zero_point); | |||
| x = x <= qmin ? qmin : x; | |||
| x = x >= qmax ? qmax : x; | |||
| x = round(x); | |||
| *out = (x - (*zero_point)) * (*scale); | |||
| ++inp; | |||
| ++out; | |||
| ++scale; | |||
| ++zero_point; | |||
| ++grad_scale; | |||
| } | |||
| } | |||
| template <typename T> | |||
| void backward_impl(const ElemwiseOpParamN<7> src, float qmin, float qmax) { | |||
| auto diff = tensor_iter_valonly<T>(src[0]).begin(); | |||
| auto input = tensor_iter_valonly<T>(src[1]).begin(); | |||
| auto scale = tensor_iter_valonly<T>(src[2]).begin(); | |||
| auto zero_point = tensor_iter_valonly<T>(src[3]).begin(); | |||
| auto grad_scale = tensor_iter_valonly<T>(src[4]).begin(); | |||
| auto grad_x = tensor_iter_valonly<T>(src[5]).begin(); | |||
| auto grad_s = tensor_iter_valonly<T>(src[6]).begin(); | |||
| size_t total = src[0].layout.total_nr_elems(); | |||
| for (size_t i = 0; i < total; ++i) { | |||
| T x = (*input) / (*scale) + (*zero_point); | |||
| bool ind_small = x < qmin; | |||
| bool ind_big = x > qmax; | |||
| bool ind_middle = ind_small ^ ind_big; | |||
| ind_middle = !ind_middle; | |||
| *grad_s = ind_small * qmin + ind_big * qmax + | |||
| ind_middle * (-x + round(x)); | |||
| *grad_s = (*grad_s) * (*grad_scale) * (*diff); | |||
| *grad_x = ind_middle * (*diff); | |||
| ++diff; | |||
| ++input; | |||
| ++scale; | |||
| ++zero_point; | |||
| ++grad_scale; | |||
| ++grad_x; | |||
| ++grad_s; | |||
| } | |||
| } | |||
| } // namespace | |||
| namespace megdnn { | |||
| namespace naive { | |||
| void LSQForwardImpl::exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, | |||
| _megdnn_tensor_in zero_point, | |||
| _megdnn_tensor_in grad_scale, | |||
| _megdnn_tensor_out output, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(input.layout, scale.layout, zero_point.layout, grad_scale.layout, | |||
| output.layout, workspace.size); | |||
| ElemwiseOpParamN<5> src; | |||
| src[0] = input; | |||
| src[1] = output; | |||
| src[2] = scale; | |||
| src[2].layout = src[2].layout.broadcast(input.layout); | |||
| src[3] = zero_point; | |||
| src[3].layout = src[3].layout.broadcast(input.layout); | |||
| src[4] = grad_scale; | |||
| src[4].layout = src[4].layout.broadcast(input.layout); | |||
| #define cb(DType) \ | |||
| if (input.layout.dtype == DType()) { \ | |||
| using T = typename DTypeTrait<DType>::ctype; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| forward_impl<T>(src, param().qmin, param().qmax)); \ | |||
| return; \ | |||
| } | |||
| cb(dtype::Float32) | |||
| #undef cb | |||
| } | |||
| void LSQBackwardImpl::exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, | |||
| _megdnn_tensor_in scale, | |||
| _megdnn_tensor_in zero_point, | |||
| _megdnn_tensor_in grad_scale, | |||
| _megdnn_tensor_out grad_x, _megdnn_tensor_out grad_s, | |||
| _megdnn_workspace workspace) { | |||
| check_exec(diff.layout, input.layout, scale.layout, zero_point.layout, | |||
| grad_scale.layout, grad_x.layout, grad_s.layout, workspace.size); | |||
| ElemwiseOpParamN<7> src; | |||
| src[0] = diff; | |||
| src[1] = input; | |||
| src[2] = scale; | |||
| src[2].layout = src[2].layout.broadcast(input.layout); | |||
| src[3] = zero_point; | |||
| src[3].layout = src[3].layout.broadcast(input.layout); | |||
| src[4] = grad_scale; | |||
| src[4].layout = src[4].layout.broadcast(input.layout); | |||
| src[5] = grad_x; | |||
| src[6] = grad_s; | |||
| #define cb(DType) \ | |||
| if (diff.layout.dtype == DType() && grad_x.layout.dtype == DType() && \ | |||
| input.layout.dtype == DType()) { \ | |||
| using T = typename DTypeTrait<DType>::ctype; \ | |||
| MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||
| backward_impl<T>(src, param().qmin, param().qmax)); \ | |||
| return; \ | |||
| } | |||
| cb(dtype::Float32) | |||
| #undef cb | |||
| } | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
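Because the naive implementation walks every operand with tensor_iter_valonly, it supports non-contiguous layouts without a separate code path. NumPy shows the same property on a strided view (illustration only):

import numpy as np
base = np.arange(32, dtype=np.float32).reshape(4, 8)
view = base[:, ::2]                   # non-contiguous view
assert not view.flags['C_CONTIGUOUS']
out = (np.round(np.clip(view / 2.0 + 1.0, -127, 127)) - 1.0) * 2.0  # elementwise math still works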
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * \file dnn/src/naive/lsq/opr_impl.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/oprs.h" | |||
| namespace megdnn { | |||
| namespace naive { | |||
| class LSQForwardImpl final : public LSQForward { | |||
| public: | |||
| using LSQForward::LSQForward; | |||
| void exec(_megdnn_tensor_in input, _megdnn_tensor_in scale, | |||
| _megdnn_tensor_in zero_point, _megdnn_tensor_in grad_scale, | |||
| _megdnn_tensor_out output, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout& /* input */, | |||
| const TensorLayout& /* scale */, | |||
| const TensorLayout& /* zero_point */, | |||
| const TensorLayout& /* grad_scale */, | |||
| const TensorLayout& /* output */) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| class LSQBackwardImpl final : public LSQBackward { | |||
| public: | |||
| using LSQBackward::LSQBackward; | |||
| void exec(_megdnn_tensor_in diff, _megdnn_tensor_in input, | |||
| _megdnn_tensor_in scale, _megdnn_tensor_in zero_point, | |||
| _megdnn_tensor_in grad_scale, _megdnn_tensor_out grad_x, | |||
| _megdnn_tensor_out grad_s, _megdnn_workspace workspace) override; | |||
| size_t get_workspace_in_bytes(const TensorLayout& /* diff */, | |||
| const TensorLayout& /* input */, | |||
| const TensorLayout& /* scale */, | |||
| const TensorLayout& /* zero_point */, | |||
| const TensorLayout& /* grad_scale */, | |||
| const TensorLayout& /* grad_x */, | |||
| const TensorLayout& /* grad_s */) override { | |||
| return 0; | |||
| } | |||
| }; | |||
| } // namespace naive | |||
| } // namespace megdnn | |||
| @@ -0,0 +1,53 @@ | |||
| /** | |||
| * \file dnn/test/common/lsq.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/basic_types.h" | |||
| #include "megdnn/opr_param_defs.h" | |||
| namespace megdnn { | |||
| namespace test { | |||
| namespace lsq { | |||
| struct TestArg { | |||
| param::LSQ param; | |||
| TensorShape ishape; | |||
| TensorShape scale_shape; | |||
| TensorShape zeropoint_shape; | |||
| TensorShape gradscale_shape; | |||
| TestArg(param::LSQ param, TensorShape ishape, TensorShape scale_shape, | |||
| TensorShape zeropoint_shape, TensorShape gradscale_shape) | |||
| : param(param), | |||
| ishape(ishape), | |||
| scale_shape(scale_shape), | |||
| zeropoint_shape(zeropoint_shape), | |||
| gradscale_shape(gradscale_shape) {} | |||
| }; | |||
| inline std::vector<TestArg> get_args() { | |||
| std::vector<TestArg> args; | |||
| param::LSQ cur_param; | |||
| cur_param.qmin = -127; | |||
| cur_param.qmax = 127; | |||
| for (size_t i = 10; i < 30; i += 2) { | |||
| args.emplace_back(cur_param, TensorShape{10, 64, i, i}, TensorShape{1}, | |||
| TensorShape{1}, TensorShape{1}); | |||
| } | |||
| return args; | |||
| } | |||
| } // namespace lsq | |||
| } // namespace test | |||
| } // namespace megdnn | |||
| @@ -0,0 +1,110 @@ | |||
| /** | |||
| * \file dnn/test/cuda/lsq.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "test/common/lsq.h" | |||
| #include "megdnn/oprs.h" | |||
| #include "test/common/checker.h" | |||
| #include "test/cuda/fixture.h" | |||
| namespace megdnn { | |||
| namespace test { | |||
| using namespace lsq; | |||
| TEST_F(CUDA, LSQ) { | |||
| std::vector<TestArg> args = get_args(); | |||
| auto dtype = dtype::Float32(); | |||
| for (auto&& arg : args) { | |||
| auto param = arg.param; | |||
| auto ishape = arg.ishape; | |||
| auto scale_shape = arg.scale_shape; | |||
| auto zeropoint_shape = arg.zeropoint_shape; | |||
| auto gradscale_shape = arg.gradscale_shape; | |||
| Checker<LSQForward> checker(handle_cuda()); | |||
| checker.set_param(param) | |||
| .set_dtype(0, dtype) | |||
| .set_dtype(1, dtype) | |||
| .set_dtype(2, dtype) | |||
| .set_dtype(3, dtype) | |||
| .set_dtype(4, dtype) | |||
| .execs({ishape, scale_shape, zeropoint_shape, gradscale_shape, | |||
| ishape}); | |||
| } | |||
| // test noncontiguous layout | |||
| for (auto&& arg : args) { | |||
| auto param = arg.param; | |||
| auto ishape = arg.ishape; | |||
| auto sshape = arg.scale_shape; | |||
| auto zeropoint_shape = arg.zeropoint_shape; | |||
| auto gradscale_shape = arg.gradscale_shape; | |||
| Checker<LSQForward> checker(handle_cuda()); | |||
| TensorLayout ilayout( | |||
| ishape, | |||
| {(long int)(ishape[1] * ishape[2] * ishape[3] * 2), | |||
| (long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1}, | |||
| dtype::Float32()); | |||
| checker.set_param(param).execl({ilayout, | |||
| {sshape, dtype::Float32()}, | |||
| {zeropoint_shape, dtype::Float32()}, | |||
| {gradscale_shape, dtype::Float32()}, | |||
| ilayout}); | |||
| } | |||
| } | |||
| TEST_F(CUDA, LSQ_BACKWARD) { | |||
| std::vector<TestArg> args = get_args(); | |||
| auto dtype = dtype::Float32(); | |||
| for (auto&& arg : args) { | |||
| auto param = arg.param; | |||
| auto ishape = arg.ishape; | |||
| auto scale_shape = arg.scale_shape; | |||
| auto zeropoint_shape = arg.zeropoint_shape; | |||
| auto gradscale_shape = arg.gradscale_shape; | |||
| Checker<LSQBackward> checker(handle_cuda()); | |||
| checker.set_param(param) | |||
| .set_dtype(0, dtype) | |||
| .set_dtype(1, dtype) | |||
| .set_dtype(2, dtype) | |||
| .set_dtype(3, dtype) | |||
| .set_dtype(4, dtype) | |||
| .set_dtype(5, dtype) | |||
| .set_dtype(6, dtype) | |||
| .execs({ishape, ishape, scale_shape, zeropoint_shape, | |||
| gradscale_shape, ishape, ishape}); | |||
| } | |||
| // test noncontiguous layout | |||
| for (auto&& arg : args) { | |||
| auto param = arg.param; | |||
| auto ishape = arg.ishape; | |||
| auto sshape = arg.scale_shape; | |||
| auto zeropoint_shape = arg.zeropoint_shape; | |||
| auto gradscale_shape = arg.gradscale_shape; | |||
| Checker<LSQBackward> checker(handle_cuda()); | |||
| TensorLayout ilayout( | |||
| ishape, | |||
| {(long int)(ishape[1] * ishape[2] * ishape[3] * 2), | |||
| (long int)(ishape[2] * ishape[3]), (long int)ishape[3], 1}, | |||
| dtype::Float32()); | |||
| checker.set_param(param).execl({ilayout, | |||
| ilayout, | |||
| {sshape, dtype::Float32()}, | |||
| {zeropoint_shape, dtype::Float32()}, | |||
| {gradscale_shape, dtype::Float32()}, | |||
| ilayout, | |||
| ilayout}); | |||
| } | |||
| } | |||
| } // namespace test | |||
| } // namespace megdnn | |||
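The non-contiguous cases above double the batch stride, leaving a gap between consecutive images. For an input of shape (10, 64, i, i) the strides are (64*i*i*2, i*i, i, 1); computed explicitly for i = 10:

n, c, h, w = 10, 64, 10, 10
strides = (c * h * w * 2, h * w, w, 1)  # matches the TensorLayout built above
assert strides == (12800, 100, 10, 1)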
| @@ -0,0 +1,45 @@ | |||
| /** | |||
| * \file dnn/test/naive/lsq.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "test/naive/fixture.h" | |||
| #include "megdnn/oprs/nn.h" | |||
| #include "test/common/checker.h" | |||
| using namespace megdnn; | |||
| using namespace test; | |||
| TEST_F(NAIVE, LSQ_FORWARD) { | |||
| Checker<LSQ> checker(handle(), /* check_dispatch */ false); | |||
| param::LSQ param; | |||
| param.qmin = -127; | |||
| param.qmax = 127; | |||
| TensorND input = | |||
| TensorValue({2, 2, 2, 2}, dtype::Float32(), | |||
| {0, 1, 3, 4, 1, 2, 4, 5, 3, 4, 6, 7, 4, 5, 7, 8}); | |||
| TensorND scale = TensorValue({1}, dtype::Float32(), {2}); | |||
| TensorND zero_point = TensorValue({1}, dtype::Float32(), {1}); | |||
| TensorND grad_scale = TensorValue({1}, dtype::Float32(), {0.5}); | |||
| TensorND output = | |||
| TensorValue({2, 2, 2, 2}, dtype::Float32(), | |||
| {0, 2, 4, 4, 2, 2, 4, 6, 4, 4, 6, 8, 4, 6, 8, 8}); | |||
| checker.set_param(param).exect( | |||
| Testcase{input, scale, zero_point, grad_scale, {}}, | |||
| Testcase{{}, {}, {}, {}, output}); | |||
| } | |||
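The expected values can be checked by hand. For the input element 1 with scale 2 and zero_point 1: 1/2 + 1 = 1.5, which rounds to 2, and (2 - 1) * 2 = 2, matching the second output entry:

import math
x, s, zp = 1.0, 2.0, 1.0
q = math.floor(min(max(x / s + zp, -127.0), 127.0) + 0.5)  # round half up, as in the test reference
assert (q - zp) * s == 2.0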
| @@ -6,7 +6,7 @@ | |||
| # software distributed under the License is distributed on an | |||
| # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| from .fake_quant import TQT, FakeQuantize | |||
| from .fake_quant import LSQ, TQT, FakeQuantize | |||
| from .observer import ( | |||
| ExponentialMovingAverageObserver, | |||
| HistogramObserver, | |||
| @@ -12,13 +12,15 @@ from .. import functional as F | |||
| from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes | |||
| from ..logger import get_logger | |||
| from ..module import Module | |||
| from ..tensor import Parameter | |||
| from ..tensor import Parameter, Tensor | |||
| from .utils import ( | |||
| LSQParams, | |||
| QParams, | |||
| QParamsModuleMixin, | |||
| QuantMode, | |||
| create_qparams, | |||
| fake_quant_tensor, | |||
| lsq_forward, | |||
| tqt_forward, | |||
| ) | |||
| @@ -117,3 +119,58 @@ class FakeQuantize(_FakeQuantize): | |||
| qparams.dtype_meta, self.dtype | |||
| ) | |||
| return fake_quant_tensor(inp, qparams) | |||
| class LSQ(_FakeQuantize, QParamsModuleMixin): | |||
| r""" | |||
| LSQ: https://arxiv.org/pdf/1902.08153.pdf, which estimates and scales the | |||
| task loss gradient at each weight and activation layer's quantizer step size. | |||
| :param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target | |||
| quantization dtype of input. | |||
| :param enable: whether to do ``normal_forward`` or ``fake_quant_forward``. | |||
| :param eps: a small value to avoid division by zero. Default: 1e-5 | |||
| """ | |||
| def __init__( | |||
| self, | |||
| dtype: Union[str, QuantDtypeMeta], | |||
| enable: bool = True, | |||
| eps: float = 1e-5, | |||
| **kwargs | |||
| ): | |||
| super().__init__(dtype=dtype, enable=enable, **kwargs) | |||
| self.eps = Tensor(eps, dtype="float32") | |||
| self.step_size = Parameter(1.0, dtype="float32") | |||
| def set_qparams(self, qparams: LSQParams): | |||
| self.mode = qparams.mode | |||
| if qparams.mode == QuantMode.ASYMMERTIC: | |||
| self.zero_point = qparams.zero_point | |||
| else: | |||
| self.zero_point = Tensor([0.0], dtype="float32") | |||
| if qparams.scale is None: | |||
| raise AssertionError("Can not get an initialized scale") | |||
| init_step_size = qparams.scale | |||
| if init_step_size < self.eps: | |||
| init_step_size = 0 | |||
| else: | |||
| init_step_size = init_step_size - self.eps | |||
| self.step_size = Parameter(init_step_size, dtype="float32") | |||
| self.grad_scale = qparams.grad_scale | |||
| def fake_quant_forward(self, inp, qparams: LSQParams = None): | |||
| step_size = F.abs(self.step_size) + self.eps | |||
| return lsq_forward( | |||
| self.qmin, self.qmax, inp, step_size, self.zero_point, self.grad_scale | |||
| ) | |||
| def get_qparams(self): | |||
| return LSQParams( | |||
| mode=self.mode, | |||
| dtype_meta=self.dtype, | |||
| scale=F.abs(self.step_size.detach()) + self.eps, | |||
| zero_point=self.zero_point, | |||
| grad_scale=self.grad_scale, | |||
| ) | |||
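Typical usage (a sketch; it assumes the usual _FakeQuantize interface where qmin/qmax are derived from the dtype, and that LSQParams are supplied before the module is applied):

from megengine.quantization.fake_quant import LSQ

fq = LSQ(dtype="qint8")  # hypothetical dtype choice
# in a QAT pipeline, fq.set_qparams(...) is called with LSQParams first,
# after which fq(inp) returns the fake-quantized tensor when enabled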
| @@ -43,6 +43,16 @@ def tqt_forward(qmin, qmax, inp, scale): | |||
| return output | |||
| def lsq_forward(qmin, qmax, inp, step_size, zero_point=None, scale_grad=None): | |||
| if zero_point is None: | |||
| zero_point = Tensor([0.0], dtype=np.float32) | |||
| if scale_grad is None: | |||
| scale_grad = Tensor([1.0], dtype=np.float32) | |||
| op = builtin.LSQ(qmin=qmin, qmax=qmax) | |||
| (output,) = apply(op, inp, step_size, zero_point, scale_grad) | |||
| return output | |||
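A usage sketch (shapes are arbitrary; zero_point and scale_grad fall back to 0 and 1):

import numpy as np
import megengine as mge

x = mge.tensor(np.random.randn(1, 2, 3, 4).astype("float32"))
step = mge.tensor(np.array([0.25], dtype="float32"))
y = lsq_forward(-127, 127, x, step)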
| def register_method_to_class(cls): | |||
| def decorator(func): | |||
| @wraps(func) | |||
| @@ -105,6 +115,47 @@ class QParams: | |||
| return "QParams({})".format(content) | |||
| class LSQParams: | |||
| """ | |||
| To standardize LSQ's qparams format. If custom | |||
| qparams is needed, inherit this class and add custom ``__slots__``. | |||
| """ | |||
| __slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale" | |||
| def __init__( | |||
| self, | |||
| mode: QuantMode, | |||
| dtype_meta: QuantDtypeMeta, | |||
| scale: Tensor, | |||
| zero_point: Tensor, | |||
| grad_scale: Tensor, | |||
| ): | |||
| self.mode = mode | |||
| self.dtype_meta = dtype_meta | |||
| self.scale = scale | |||
| self.zero_point = zero_point | |||
| self.grad_scale = grad_scale | |||
| def update(self, lsqparams: "LSQParams"): | |||
| for key in self.__slots__: | |||
| setattr(self, key, getattr(lsqparams, key)) | |||
| def __eq__(self, other): | |||
| if len(self.__slots__) != len(other.__slots__): | |||
| return False | |||
| for key in self.__slots__: | |||
| if not hasattr(other, key) or getattr(self, key) != getattr(other, key): | |||
| return False | |||
| return True | |||
| def __repr__(self): | |||
| content = ", ".join( | |||
| ["{}={}".format(key, getattr(self, key)) for key in self.__slots__] | |||
| ) | |||
| return "LSQParams({})".format(content) | |||
| class QParamsModuleMixin(abc.ABC): | |||
| def get_quantized_dtype(self): | |||
| qparams = self.get_qparams() | |||
| @@ -10,6 +10,7 @@ import numpy as np | |||
| import pytest | |||
| import megengine as mge | |||
| import megengine.functional as F | |||
| from megengine import tensor | |||
| from megengine.core.autodiff.grad import Function, Grad | |||
| from megengine.core.tensor.dtype import QuantDtypeMeta | |||
| @@ -19,6 +20,7 @@ from megengine.quantization.utils import ( | |||
| QuantMode, | |||
| create_qparams, | |||
| fake_quant_tensor, | |||
| lsq_forward, | |||
| tqt_forward, | |||
| ) | |||
| @@ -150,3 +152,78 @@ def test_fakequant(): | |||
| zero_point = tensor(1.0 * np.ones((1, 32, 1, 1)), dtype=np.float32) | |||
| scale = tensor(4.0 * np.ones((1, 32, 1, 1)), dtype=np.float32) | |||
| run(zero_point, scale) | |||
| class LSQ_numpy: | |||
| def __init__(self, lowerbound, upperbound): | |||
| super().__init__() | |||
| self.lowerbound = lowerbound | |||
| self.upperbound = upperbound | |||
| def forward(self, inp, scale, zero_point, grad_scale): | |||
| inp_scaled = inp / scale + zero_point | |||
| inp_clipped = np.maximum( | |||
| np.minimum(inp_scaled, self.upperbound), self.lowerbound | |||
| ) | |||
| inp_rounded = np.floor(inp_clipped + 0.5) | |||
| inp_flq = (inp_rounded - zero_point) * scale | |||
| self.saved_tensors = (inp_scaled, inp_rounded, scale, grad_scale) | |||
| return inp_flq | |||
| def backward(self, grad_inp_flq): | |||
| (inp_scaled, inp_rounded, scale, grad_scale) = self.saved_tensors | |||
| ind_small = inp_scaled < self.lowerbound | |||
| ind_big = inp_scaled > self.upperbound | |||
| ind_middle = np.logical_xor(ind_small, ind_big) | |||
| ind_middle = np.abs(ind_middle - 1) | |||
| grad_s = ( | |||
| ind_small * self.lowerbound | |||
| + ind_big * self.upperbound | |||
| + ind_middle * (-inp_scaled + inp_rounded) | |||
| ) | |||
| grad_s = grad_s * grad_scale * grad_inp_flq | |||
| grad_s = grad_s.sum() | |||
| grad_inp = grad_inp_flq * ind_middle | |||
| return grad_inp, grad_s | |||
| def test_lsq(): | |||
| def preprocess(scale, eps): | |||
| scale = np.array([0]) if scale < eps else scale - eps | |||
| return np.abs(scale) + eps | |||
| g = [] | |||
| def cb(grad): | |||
| g.append(grad) | |||
| x = np.random.randint(-128, 128, size=(1, 2, 3, 4)).astype("float32") | |||
| s = np.random.rand(1) | |||
| eps = np.array([1e-5], dtype="float32") | |||
| s = preprocess(s, eps) | |||
| zero_point = np.array([1.0], dtype="float32") | |||
| grad_s = np.array([2.0], dtype="float32") | |||
| g_y = np.ones(shape=(1, 2, 3, 4), dtype="float32") | |||
| n = LSQ_numpy(-127, 127) | |||
| y_np = n.forward(x, s, zero_point, grad_s) | |||
| g_x_np, g_s_np = n.backward(g_y) | |||
| x = mge.tensor(x, dtype="float32") | |||
| s = mge.tensor(s, dtype="float32") | |||
| zero_point = mge.tensor(zero_point, dtype="float32") | |||
| grad_s = mge.tensor(grad_s, dtype="float32") | |||
| g_y = mge.tensor(g_y, dtype="float32") | |||
| grad = Grad().wrt(x, s, callback=cb) | |||
| y = lsq_forward(-127, 127, x, s, zero_point, grad_s) | |||
| grad(y, g_y) | |||
| g_x, g_s = g | |||
| np.testing.assert_allclose(y.numpy(), y_np, rtol=1e-7, atol=1e-7) | |||
| np.testing.assert_allclose(g_x.numpy(), g_x_np, rtol=1e-7, atol=1e-7) | |||
| np.testing.assert_allclose(g_s.numpy(), g_s_np, rtol=5e-7, atol=5e-7) | |||
| @@ -6,23 +6,26 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| // FIXME: split this file into separate files for each specialized op | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/blas.h" | |||
| #include "megbrain/opr/dnn/adaptive_pooling.h" | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/opr/dnn/correlation.h" | |||
| #include "megbrain/opr/dnn/fake_quant.h" | |||
| #include "megbrain/opr/dnn/tqt.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/dnn/images2neibs.h" | |||
| #include "megbrain/opr/dnn/local.h" | |||
| #include "megbrain/opr/dnn/lsq.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/dnn/roi_align.h" | |||
| #include "megbrain/opr/dnn/correlation.h" | |||
| #include "megbrain/opr/dnn/roi_pooling.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/blas.h" | |||
| #include "megbrain/opr/dnn/tqt.h" | |||
| #include "megbrain/opr/imgproc.h" | |||
| #include "megbrain/opr/indexing.h" | |||
| #include "megbrain/opr/io.h" | |||
| @@ -32,40 +35,38 @@ | |||
| #include "megbrain/opr/tensor_gen.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/opr/utility.h" | |||
| #include "megbrain/opr/dnn/images2neibs.h" | |||
| #include "../op_trait.h" | |||
| namespace mgb::imperative { | |||
| namespace { namespace dimshuffle { | |||
| namespace { | |||
| namespace dimshuffle { | |||
| std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
| auto* node = &node_->cast_final_safe<opr::Dimshuffle>(); | |||
| std::vector<int> pattern(node->param().pattern_len); | |||
| for (size_t i = 0; i < node->param().pattern_len; ++ i) { | |||
| for (size_t i = 0; i < node->param().pattern_len; ++i) { | |||
| pattern[i] = node->param().pattern[i]; | |||
| } | |||
| return Dimshuffle::make(pattern); | |||
| } | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& ds = static_cast<const Dimshuffle&>(def); | |||
| OperatorNodeConfig config{ds.make_name()}; | |||
| return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config); | |||
| } | |||
| OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle) | |||
| .make_from_op_node(make_from_op_node) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // dimshuffle | |||
| .make_from_op_node(make_from_op_node) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace dimshuffle | |||
| } // namespace | |||
| namespace { namespace add_axis { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace add_axis { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& add_axis = static_cast<const AddAxis&>(def); | |||
| using Desc = opr::AxisAddRemove::AxisDesc; | |||
| std::vector<Desc> param; | |||
| @@ -76,15 +77,13 @@ auto apply_on_var_node( | |||
| return opr::AxisAddRemove::make(inputs[0], param, config); | |||
| } | |||
| OP_TRAIT_REG(AddAxis, AddAxis) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // add_axis | |||
| OP_TRAIT_REG(AddAxis, AddAxis).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace add_axis | |||
| } // namespace | |||
| namespace { namespace remove_axis { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace remove_axis { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& remove_axis = static_cast<const RemoveAxis&>(def); | |||
| using Desc = opr::AxisAddRemove::AxisDesc; | |||
| std::vector<Desc> param; | |||
| @@ -96,36 +95,35 @@ auto apply_on_var_node( | |||
| } | |||
| OP_TRAIT_REG(RemoveAxis, RemoveAxis) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // remove_axis | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace remove_axis | |||
| } // namespace | |||
| namespace { namespace top_k { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace top_k { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& topk = static_cast<const TopK&>(def); | |||
| OperatorNodeConfig config{topk.make_name()}; | |||
| return opr::TopK::make(inputs[0], inputs[1], topk.param(), config)[0] | |||
| .node()->owner_opr(); | |||
| .node() | |||
| ->owner_opr(); | |||
| } | |||
| OP_TRAIT_REG(TopK, TopK) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // top_k | |||
| OP_TRAIT_REG(TopK, TopK).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace top_k | |||
| } // namespace | |||
| namespace { namespace reduce { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace reduce { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& reduce = static_cast<const Reduce&>(def); | |||
| OperatorNodeConfig config{reduce.make_name()}; | |||
| if (inputs.size() > 1) { | |||
| return opr::Reduce::make(inputs[0], reduce.param(), inputs[1], config); | |||
| } else { | |||
| return opr::Reduce::make( | |||
| inputs[0], reduce.param(), (cg::VarNode*)nullptr, config); | |||
| return opr::Reduce::make(inputs[0], reduce.param(), | |||
| (cg::VarNode*)nullptr, config); | |||
| } | |||
| } | |||
| @@ -135,86 +133,92 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
| } | |||
| OP_TRAIT_REG(Reduce, Reduce, opr::Reduce) | |||
| .make_from_op_node(make_from_op_node) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // reduce | |||
| .make_from_op_node(make_from_op_node) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace reduce | |||
| } // namespace | |||
| namespace { namespace adaptive_pooling { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace adaptive_pooling { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& pool = static_cast<const AdaptivePooling&>(def); | |||
| OperatorNodeConfig config{pool.make_name()}; | |||
| return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), config); | |||
| return opr::AdaptivePooling::make(inputs[0], inputs[1], pool.param(), | |||
| config); | |||
| } | |||
| OP_TRAIT_REG(AdaptivePooling, AdaptivePooling) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // adaptive_pooling | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace adaptive_pooling | |||
| } // namespace | |||
| namespace { namespace conv_bias { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace conv_bias { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& conv = static_cast<const ConvBias&>(def); | |||
| cg::OperatorNodeConfig config{conv.dtype}; | |||
| config.name(conv.make_name()); | |||
| if (inputs.size() == 2) { | |||
| return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
| return opr::ConvBias::make(inputs[0], inputs[1], conv.param(), | |||
| conv.policy(), config); | |||
| } else if (inputs.size() == 3) { | |||
| return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); | |||
| return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], | |||
| conv.param(), conv.policy(), config); | |||
| } else if (inputs.size() == 4) { | |||
| return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config); | |||
| return opr::ConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], | |||
| conv.param(), conv.policy(), config); | |||
| } | |||
| mgb_assert(0); | |||
| } | |||
| OP_TRAIT_REG(ConvBias, ConvBias) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // conv_bias | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace conv_bias | |||
| } // namespace | |||
| namespace { namespace batch_conv_bias { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace batch_conv_bias { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& conv = static_cast<const BatchConvBias&>(def); | |||
| cg::OperatorNodeConfig config{conv.dtype}; | |||
| config.name(conv.make_name()); | |||
| if (inputs.size() == 2) { | |||
| return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), conv.policy(), config); | |||
| return opr::BatchConvBias::make(inputs[0], inputs[1], conv.param(), | |||
| conv.policy(), config); | |||
| } else if (inputs.size() == 3) { | |||
| return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], conv.param(), conv.policy(), config); | |||
| return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], | |||
| conv.param(), conv.policy(), config); | |||
| } else if (inputs.size() == 4) { | |||
| return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], inputs[3], conv.param(), conv.policy(), config); | |||
| return opr::BatchConvBias::make(inputs[0], inputs[1], inputs[2], | |||
| inputs[3], conv.param(), conv.policy(), | |||
| config); | |||
| } | |||
| mgb_assert(0); | |||
| } | |||
| OP_TRAIT_REG(BatchConvBias, BatchConvBias) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // batch_conv_bias | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace batch_conv_bias | |||
| } // namespace | |||
| namespace { namespace pooling { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace pooling { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& pool = static_cast<const Pooling&>(def); | |||
| OperatorNodeConfig config{pool.make_name()}; | |||
| return opr::Pooling::make(inputs[0], pool.param(), config); | |||
| } | |||
| OP_TRAIT_REG(Pooling, Pooling) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // pooling | |||
| OP_TRAIT_REG(Pooling, Pooling).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace pooling | |||
| } // namespace | |||
| namespace { namespace matrix_mul { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace matrix_mul { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& matmul = static_cast<const MatrixMul&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| OperatorNodeConfig config{matmul.make_name()}; | |||
| @@ -222,14 +226,14 @@ auto apply_on_var_node( | |||
| matmul.policy(), config); | |||
| } | |||
| OP_TRAIT_REG(MatrixMul, MatrixMul) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // matrix_mul | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace matrix_mul | |||
| } // namespace | |||
| namespace { namespace batched_matrix_mul { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace batched_matrix_mul { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& matmul = static_cast<const BatchedMatrixMul&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| OperatorNodeConfig config{matmul.make_name()}; | |||
| @@ -237,166 +241,155 @@ auto apply_on_var_node( | |||
| matmul.policy(), config); | |||
| } | |||
| OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // batched_matrix_mul | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace batched_matrix_mul | |||
| } // namespace | |||
| namespace { namespace dot { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace dot { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = def.cast_final_safe<Dot>(); | |||
| mgb_assert(inputs.size() == 2); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::Dot::make(inputs[0], inputs[1], config); | |||
| } | |||
| OP_TRAIT_REG(Dot, Dot) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // dot | |||
| OP_TRAIT_REG(Dot, Dot).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace dot | |||
| } // namespace | |||
| namespace { namespace argsort { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace argsort { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& argsort = static_cast<const Argsort&>(def); | |||
| OperatorNodeConfig config{argsort.make_name()}; | |||
| return opr::Argsort::make(inputs[0], argsort.param(), config); | |||
| } | |||
| OP_TRAIT_REG(Argsort, Argsort) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // argsort | |||
| OP_TRAIT_REG(Argsort, Argsort).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace argsort | |||
| } // namespace | |||
| namespace { namespace argmax { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace argmax { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& argmax = static_cast<const Argmax&>(def); | |||
| OperatorNodeConfig config{argmax.make_name()}; | |||
| return opr::Argmax::make(inputs[0], argmax.param(), config); | |||
| } | |||
| OP_TRAIT_REG(Argmax, Argmax) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // argmax | |||
| OP_TRAIT_REG(Argmax, Argmax).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace argmax | |||
| } // namespace | |||
| namespace { namespace argmin { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace argmin { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& argmin = static_cast<const Argmin&>(def); | |||
| OperatorNodeConfig config{argmin.make_name()}; | |||
| return opr::Argmin::make(inputs[0], argmin.param(), config); | |||
| } | |||
| OP_TRAIT_REG(Argmin, Argmin) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // argmin | |||
| OP_TRAIT_REG(Argmin, Argmin).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace argmin | |||
| } // namespace | |||
| namespace { namespace warp_perspective { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace warp_perspective { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& warp = static_cast<const WarpPerspective&>(def); | |||
| OperatorNodeConfig config{warp.make_name()}; | |||
| if (inputs.size() == 3) { | |||
| return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], warp.param(), config); | |||
| return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], | |||
| warp.param(), config); | |||
| } else { | |||
| mgb_assert(inputs.size() == 4); | |||
| return opr::WarpPerspective::make( | |||
| inputs[0], inputs[1], inputs[2], inputs[3], warp.param(), config); | |||
| return opr::WarpPerspective::make(inputs[0], inputs[1], inputs[2], | |||
| inputs[3], warp.param(), config); | |||
| } | |||
| } | |||
| OP_TRAIT_REG(WarpPerspective, WarpPerspective) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // warp_perspective | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace warp_perspective | |||
| } // namespace | |||
| namespace { namespace group_local { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace group_local { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& local = static_cast<const GroupLocal&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| OperatorNodeConfig config{local.make_name()}; | |||
| return opr::GroupLocal::make(inputs[0], inputs[1], local.param(), config); | |||
| } | |||
| OP_TRAIT_REG(GroupLocal, GroupLocal) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // group_local | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace group_local | |||
| } // namespace | |||
| namespace { namespace indexing_one_hot { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace indexing_one_hot { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const IndexingOneHot&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::IndexingOneHot::make(inputs[0], inputs[1], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(IndexingOneHot, IndexingOneHot) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // indexing_one_hot | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace indexing_one_hot | |||
| } // namespace | |||
| namespace { namespace indexing_set_one_hot { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace indexing_set_one_hot { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const IndexingSetOneHot&>(def); | |||
| mgb_assert(inputs.size() == 3); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], op.param(), config); | |||
| return opr::IndexingSetOneHot::make(inputs[0], inputs[1], inputs[2], | |||
| op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(IndexingSetOneHot, IndexingSetOneHot) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // indexing_set_one_hot | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace indexing_set_one_hot | |||
| } // namespace | |||
| namespace { namespace typecvt { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace typecvt { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const TypeCvt&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::TypeCvt::make(inputs[0], op.dtype, config); | |||
| } | |||
| OP_TRAIT_REG(TypeCvt, TypeCvt) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // typecvt | |||
| OP_TRAIT_REG(TypeCvt, TypeCvt).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace typecvt | |||
| } // namespace | |||
| namespace { namespace concat { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace concat { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Concat&>(def); | |||
| cg::OperatorNodeConfig config{op.comp_node}; | |||
| config.name(op.make_name()); | |||
| return opr::Concat::make(inputs, op.axis, config); | |||
| } | |||
| OP_TRAIT_REG(Concat, Concat) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // concat | |||
| OP_TRAIT_REG(Concat, Concat).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace concat | |||
| } // namespace | |||
| namespace { namespace copy { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace copy { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Copy&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| cg::OperatorNodeConfig config{op.comp_node}; | |||
| config.name(op.make_name()); | |||
| return opr::Copy::make(inputs[0], config); | |||
| } | |||
| OP_TRAIT_REG(Copy, Copy) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // copy | |||
| OP_TRAIT_REG(Copy, Copy).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace copy | |||
| } // namespace | |||
| namespace { namespace assert_equal { | |||
| auto apply_on_var_node( | |||
| @@ -408,81 +401,81 @@ auto apply_on_var_node( | |||
| } else { | |||
| // workaround for MiniGraph, which only allows one opr in the graph | |||
| mgb_assert(inputs.size() == 3); | |||
| return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], op.param(), {}); | |||
| return opr::AssertEqual::make(inputs[0], inputs[1], inputs[2], | |||
| op.param(), {}); | |||
| } | |||
| } | |||
| OP_TRAIT_REG(AssertEqual, AssertEqual) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // assert_equal | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace assert_equal | |||
| } // namespace | |||
| namespace { namespace roi_align { | |||
| VarNodeArray apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace roi_align { | |||
| VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const ROIAlign&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| auto* opr = opr::ROIAlign::make( | |||
| inputs[0], inputs[1], op.param(), config).node()->owner_opr(); | |||
| auto* opr = opr::ROIAlign::make(inputs[0], inputs[1], op.param(), config) | |||
| .node() | |||
| ->owner_opr(); | |||
| return {opr->output(0), opr->output(1)}; | |||
| } | |||
| OP_TRAIT_REG(ROIAlign, ROIAlign) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // roi_align | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace roi_align | |||
| } // namespace | |||
| namespace { namespace correlation { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace correlation { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Correlation&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::Correlation::make( | |||
| inputs[0], inputs[1], op.param(), config); | |||
| return opr::Correlation::make(inputs[0], inputs[1], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(Correlation, Correlation) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // correlation | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace correlation | |||
| } // namespace | |||
| #if MGB_CUDA | |||
| namespace { namespace nvof { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace nvof { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const NvOf&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::NvOf::make(inputs[0], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(NvOf, NvOf) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // nvof | |||
| OP_TRAIT_REG(NvOf, NvOf).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace nvof | |||
| } // namespace | |||
| #endif | |||
| namespace { namespace linspace { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace linspace { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Linspace&>(def); | |||
| mgb_assert(inputs.size() == 3); | |||
| cg::OperatorNodeConfig config{op.comp_node}; | |||
| config.name(op.make_name()); | |||
| return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), config); | |||
| return opr::Linspace::make(inputs[0], inputs[1], inputs[2], op.param(), | |||
| config); | |||
| } | |||
| OP_TRAIT_REG(Linspace, Linspace) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // linspace | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace linspace | |||
| } // namespace | |||
| namespace { namespace eye { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace eye { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Eye&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| cg::OperatorNodeConfig config{op.comp_node}; | |||
| @@ -490,58 +483,59 @@ auto apply_on_var_node( | |||
| opr::Eye::Param param{op.k, op.dtype.enumv()}; | |||
| return opr::Eye::make(inputs[0], param, config); | |||
| } | |||
| OP_TRAIT_REG(Eye, Eye) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // eye | |||
| OP_TRAIT_REG(Eye, Eye).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace eye | |||
| } // namespace | |||
| namespace { namespace roi_pooling { | |||
| VarNodeArray apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace roi_pooling { | |||
| VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const ROIPooling&>(def); | |||
| mgb_assert(inputs.size() == 3); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| auto* opr = opr::ROIPooling::make( | |||
| inputs[0], inputs[1], inputs[2], op.param(), config | |||
| ).node()->owner_opr(); | |||
| auto* opr = opr::ROIPooling::make(inputs[0], inputs[1], inputs[2], | |||
| op.param(), config) | |||
| .node() | |||
| ->owner_opr(); | |||
| return {opr->output(0), opr->output(1)}; | |||
| } | |||
| OP_TRAIT_REG(ROIPooling, ROIPooling) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // roi_pooling | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace roi_pooling | |||
| } // namespace | |||
| namespace { namespace remap { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace remap { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Remap&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::Remap::make(inputs[0], inputs[1], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(Remap, Remap) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // remap | |||
| OP_TRAIT_REG(Remap, Remap).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace remap | |||
| } // namespace | |||
| namespace { | |||
| auto get_index( | |||
| const VarNodeArray& inputs, size_t vidx, | |||
| const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) { | |||
| const VarNodeArray& inputs, size_t vidx, | |||
| const std::vector<std::tuple<int8_t, bool, bool, bool, bool>>& mask) { | |||
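|     // each mask entry is (axis, has_begin, has_end, has_step, has_idx); | |||
|     // the flags say how many of the remaining input vars the item | |||
|     // consumes, starting at position vidx | |||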
| size_t length = mask.size(); | |||
| opr::Subtensor::IndexDesc ret(length); | |||
| for (size_t i = 0; i < length; ++ i) { | |||
| for (size_t i = 0; i < length; ++i) { | |||
| auto&& [axis, begin, end, step, idx] = mask[i]; | |||
| ret[i].axis = axis; | |||
| if (idx) { | |||
| ret[i].idx = inputs[vidx++]; | |||
| } else { | |||
| mgb_assert(begin || end || step); | |||
| if (begin) ret[i].begin = inputs[vidx++]; | |||
| if (end) ret[i].end = inputs[vidx++]; | |||
| if (step) ret[i].step = inputs[vidx++]; | |||
| if (begin) | |||
| ret[i].begin = inputs[vidx++]; | |||
| if (end) | |||
| ret[i].end = inputs[vidx++]; | |||
| if (step) | |||
| ret[i].step = inputs[vidx++]; | |||
| } | |||
| } | |||
| mgb_assert(vidx == inputs.size()); | |||
| @@ -550,19 +544,19 @@ auto get_index( | |||
| #define IN1 inputs[0] | |||
| #define IN2 inputs[0], inputs[1] | |||
| #define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \ | |||
| namespace NAME##_impl { \ | |||
| auto apply_on_var_node( \ | |||
| const OpDef& def, \ | |||
| const VarNodeArray& inputs) { \ | |||
| auto&& op = static_cast<const NAME&>(def); \ | |||
| OperatorNodeConfig config{op.make_name()}; \ | |||
| return opr::NAME::make(IN##NR_INPUT, get_index(inputs, NR_INPUT, op.items), config); \ | |||
| } \ | |||
| OP_TRAIT_REG(NAME, NAME) \ | |||
| .apply_on_var_node(apply_on_var_node) \ | |||
| .fallback(); \ | |||
| } | |||
| #define FANCY_INDEXING_IMPL(NAME, NR_INPUT) \ | |||
| namespace NAME##_impl { \ | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { \ | |||
| auto&& op = static_cast<const NAME&>(def); \ | |||
| OperatorNodeConfig config{op.make_name()}; \ | |||
| return opr::NAME::make(IN##NR_INPUT, \ | |||
| get_index(inputs, NR_INPUT, op.items), \ | |||
| config); \ | |||
| } \ | |||
| OP_TRAIT_REG(NAME, NAME) \ | |||
| .apply_on_var_node(apply_on_var_node) \ | |||
| .fallback(); \ | |||
| } | |||
| FANCY_INDEXING_IMPL(Subtensor, 1) | |||
| FANCY_INDEXING_IMPL(SetSubtensor, 2) | |||
| @@ -580,76 +574,88 @@ FANCY_INDEXING_IMPL(BatchedSetMeshIndexing, 2) | |||
| #undef FANCY_INDEXING_IMPL | |||
| #undef IN1 | |||
| #undef IN2 | |||
| } // anonymous namespace | |||
| } // anonymous namespace | |||
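| (Roughly what FANCY_INDEXING_IMPL(Subtensor, 1) expands to, with the IN1 | |||
| placeholder already substituted; a sketch for orientation only:) | |||
| namespace Subtensor_impl { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
|     auto&& op = static_cast<const Subtensor&>(def); | |||
|     OperatorNodeConfig config{op.make_name()}; | |||
|     // inputs[0] is the tensor itself; the remaining vars are the index | |||
|     // items that get_index() decodes according to op.items | |||
|     return opr::Subtensor::make(inputs[0], get_index(inputs, 1, op.items), | |||
|                                 config); | |||
| } | |||
| OP_TRAIT_REG(Subtensor, Subtensor) | |||
|         .apply_on_var_node(apply_on_var_node) | |||
|         .fallback(); | |||
| }  // namespace Subtensor_impl | |||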
| namespace { namespace fake_quant { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace fake_quant { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const FakeQuant&>(def); | |||
| mgb_assert(inputs.size() == 3); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), config); | |||
| return opr::FakeQuant::make(inputs[0], inputs[1], inputs[2], op.param(), | |||
| config); | |||
| } | |||
| OP_TRAIT_REG(FakeQuant, FakeQuant) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // fake_quant | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace fake_quant | |||
| } // namespace | |||
| namespace { namespace tqt { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace tqt { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const TQT&>(def); | |||
| mgb_assert(inputs.size() == 2); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::TQT::make(inputs[0], inputs[1], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(TQT, TQT) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // tqt | |||
| OP_TRAIT_REG(TQT, TQT).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace tqt | |||
| } // namespace | |||
| namespace { namespace elemwise_multi_type { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace elemwise_multi_type { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const ElemwiseMultiType&>(def); | |||
| OperatorNodeConfig config{op.dtype}; | |||
| config.name(op.make_name()); | |||
| return opr::ElemwiseMultiType::make(inputs, op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(ElemwiseMultiType, ElemwiseMultiType) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // elemwise_multi_type | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace elemwise_multi_type | |||
| } // namespace | |||
| namespace { namespace svd { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace svd { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const SVD&>(def); | |||
| mgb_assert(inputs.size() == 1); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::SVD::make(inputs[0], op.param(), config)[0] | |||
| .node()->owner_opr()->usable_output(); | |||
| .node() | |||
| ->owner_opr() | |||
| ->usable_output(); | |||
| } | |||
| OP_TRAIT_REG(SVD, SVD) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // svd | |||
| OP_TRAIT_REG(SVD, SVD).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace svd | |||
| } // namespace | |||
| namespace { namespace images2neibs { | |||
| auto apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| namespace { | |||
| namespace images2neibs { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const Images2Neibs&>(def); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::Images2Neibs::make(inputs[0], op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(Images2Neibs, Images2Neibs) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| }} // images2neibs | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace images2neibs | |||
| } // namespace | |||
| namespace { | |||
| namespace lsq { | |||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| auto&& op = static_cast<const LSQ&>(def); | |||
| mgb_assert(inputs.size() == 4); | |||
| OperatorNodeConfig config{op.make_name()}; | |||
| return opr::LSQ::make(inputs[0], inputs[1], inputs[2], inputs[3], | |||
| op.param(), config); | |||
| } | |||
| OP_TRAIT_REG(LSQ, LSQ).apply_on_var_node(apply_on_var_node).fallback(); | |||
| } // namespace lsq | |||
| } // namespace | |||
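| (For context, a minimal element-wise sketch of what the LSQ op computes, | |||
| assuming the standard LSQ formulation with a zero point: clamp onto the | |||
| [qmin, qmax] grid, round, then dequantize. The authoritative kernel lives | |||
| in the megdnn implementation, so treat this only as a reference:) | |||
| // lsq_forward_ref: hypothetical scalar reference, not the megdnn kernel | |||
| #include <algorithm> | |||
| #include <cmath> | |||
| float lsq_forward_ref(float x, float scale, float zero_point, int qmin, | |||
|                       int qmax) { | |||
|     // map into the integer grid and clamp to the quantized range | |||
|     float q = std::min(std::max(x / scale + zero_point, float(qmin)), | |||
|                        float(qmax)); | |||
|     q = std::round(q);  // snap to the nearest quantization level | |||
|     return (q - zero_point) * scale;  // dequantize back to real values | |||
| } | |||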
| } // namespace mgb::imperative | |||
| } // namespace mgb::imperative | |||
| @@ -6,22 +6,24 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "./helper.h" | |||
| #include "megbrain/imperative/backward_graph_opt.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/imperative/ops/opr_attr.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/dnn/batch_norm.h" | |||
| #include "megbrain/imperative/ops/opr_attr.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/imperative/backward_graph_opt.h" | |||
| using namespace mgb; | |||
| using namespace cg; | |||
| using namespace imperative; | |||
| template <typename T> | |||
| T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, const T& outputs, const T& grads) { | |||
| T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs, | |||
| const T& outputs, const T& grads) { | |||
| T ret; | |||
| size_t i = 0; | |||
| for (auto&& t : inputs) { | |||
| @@ -54,7 +56,9 @@ T expand_grads(const U& bg, const T& outputs) { | |||
| } | |||
| template <typename T> | |||
| T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, const T& precomp, const T& inputs, const T& outputs, const T& grads) { | |||
| T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, | |||
| const T& precomp, const T& inputs, | |||
| const T& outputs, const T& grads) { | |||
| T ret = precomp; | |||
| size_t i = 0; | |||
| for (auto&& t : inputs) { | |||
| @@ -75,7 +79,8 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, cons | |||
| return ret; | |||
| } | |||
| SmallVector<TensorPtr> apply_shared_on_physical_tensor(std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) { | |||
| SmallVector<TensorPtr> apply_shared_on_physical_tensor( | |||
| std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) { | |||
| return OpDef::apply_on_physical_tensor(*def, inputs); | |||
| } | |||
| @@ -83,7 +88,7 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
| HostTensorGenerator<> gen; | |||
| SmallVector<HostTensorND> hvs; | |||
| SmallVector<TensorPtr> inputs; | |||
| for(size_t i = 0; i < 2; ++ i) { | |||
| for (size_t i = 0; i < 2; ++i) { | |||
| hvs.push_back(*gen({42})); | |||
| inputs.push_back(Tensor::make(hvs.back())); | |||
| } | |||
| @@ -97,7 +102,8 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
| for (auto&& i : inputs) { | |||
| input_descs.push_back({i->layout(), i->comp_node()}); | |||
| } | |||
| auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, {true}); | |||
| auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, | |||
| {true}); | |||
| auto&& save_for_backward = result.save_for_backward; | |||
| auto&& input_has_grad = result.input_has_grad; | |||
| @@ -106,9 +112,9 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
| hvs.push_back(*gen({42})); | |||
| inputs.push_back(Tensor::make(hvs.back())); | |||
| mgb_assert(save_for_backward.size() == inputs.size()); | |||
| for (size_t i = 0; i < inputs.size(); ++ i) { | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| if (!save_for_backward[i]) { | |||
| inputs[i].reset(); // drop unused tensor | |||
| inputs[i].reset(); // drop unused tensor | |||
| } | |||
| } | |||
| SmallVector<TensorPtr> backward_graph_inputs; | |||
| @@ -118,13 +124,11 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
| } | |||
| } | |||
| inputs.clear(); | |||
| auto input_grads = result.backward.apply( | |||
| backward_graph_inputs, | |||
| apply_shared_on_physical_tensor, | |||
| [&](auto&& x){ return x; } | |||
| ); | |||
| auto input_grads = result.backward.apply(backward_graph_inputs, | |||
| apply_shared_on_physical_tensor, | |||
| [&](auto&& x) { return x; }); | |||
| mgb_assert(input_grads.size() == input_has_grad.size()); | |||
| for (size_t i = 0; i < input_has_grad.size(); ++ i) { | |||
| for (size_t i = 0; i < input_has_grad.size(); ++i) { | |||
| mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i])); | |||
| } | |||
| @@ -133,9 +137,10 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
| res.emplace_back(); | |||
| res.back().copy_from(i->dev_tensor()).sync(); | |||
| } | |||
| for (size_t i = 0; i < 42; ++ i) { | |||
| for (size_t j = 0; j < 1; ++ j) { | |||
| ASSERT_EQ(hvs[2].ptr<float>()[i] * hvs[j].ptr<float>()[i], res[j ^ 1].ptr<float>()[i]); | |||
| for (size_t i = 0; i < 42; ++i) { | |||
| for (size_t j = 0; j < 1; ++j) { | |||
| ASSERT_EQ(hvs[2].ptr<float>()[i] * hvs[j].ptr<float>()[i], | |||
| res[j ^ 1].ptr<float>()[i]); | |||
| } | |||
| } | |||
| } | |||
| @@ -152,7 +157,8 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||
| SmallVector<LogicalTensorDesc> input_descs; | |||
| input_descs.push_back({a->layout(), a->comp_node()}); | |||
| auto result = OpDef::make_backward_graph(*attr, input_descs, {true}, {true}); | |||
| auto result = | |||
| OpDef::make_backward_graph(*attr, input_descs, {true}, {true}); | |||
| auto&& save_for_backward = result.save_for_backward; | |||
| auto&& input_has_grad = result.input_has_grad; | |||
| @@ -160,9 +166,9 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||
| inputs.push_back(outputs[0]); | |||
| inputs.push_back(dc); | |||
| mgb_assert(save_for_backward.size() == inputs.size()); | |||
| for (size_t i = 0; i < inputs.size(); ++ i) { | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| if (!save_for_backward[i]) { | |||
| inputs[i].reset(); // drop unused tensor | |||
| inputs[i].reset(); // drop unused tensor | |||
| } | |||
| } | |||
| SmallVector<TensorPtr> backward_graph_inputs; | |||
| @@ -172,19 +178,17 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||
| } | |||
| } | |||
| inputs.clear(); | |||
| auto input_grads = result.backward.apply( | |||
| backward_graph_inputs, | |||
| apply_shared_on_physical_tensor, | |||
| [&](auto&& x){ return x; } | |||
| ); | |||
| auto input_grads = result.backward.apply(backward_graph_inputs, | |||
| apply_shared_on_physical_tensor, | |||
| [&](auto&& x) { return x; }); | |||
| mgb_assert(input_grads.size() == input_has_grad.size()); | |||
| for (size_t i = 0; i < input_has_grad.size(); ++ i) { | |||
| for (size_t i = 0; i < input_has_grad.size(); ++i) { | |||
| mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i])); | |||
| } | |||
| HostTensorND hv; | |||
| hv.copy_from(input_grads[0]->dev_tensor()).sync(); | |||
| for (size_t i = 0; i < 42; ++ i) { | |||
| for (size_t i = 0; i < 42; ++i) { | |||
| ASSERT_EQ(host_dc->ptr<float>()[i], hv.ptr<float>()[i]); | |||
| } | |||
| } | |||
| @@ -192,7 +196,7 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||
| TEST(TestImperative, BatchNormGrad) { | |||
| auto cn = CompNode::load("xpux"); | |||
| using Param = opr::BatchNorm::Param; | |||
| size_t N=2, C=3, H=5, W=5; | |||
| size_t N = 2, C = 3, H = 5, W = 5; | |||
| LogicalTensorDesc inp{TensorLayout{{N, C, H, W}, dtype::Float32()}, cn}; | |||
| LogicalTensorDesc stat{TensorLayout{{C}, dtype::Float32()}, cn}; | |||
| { | |||
| @@ -202,7 +206,8 @@ TEST(TestImperative, BatchNormGrad) { | |||
| param.fwd_mode = Param::FwdMode::TRAINING; | |||
| attr.param.write_pod(param); | |||
| OpDef::make_backward_graph(attr, {inp, stat, stat, stat, stat}, | |||
| {true, true ,true, false, false}, {false, false, false, false, true}); | |||
| {true, true, true, false, false}, | |||
| {false, false, false, false, true}); | |||
| } | |||
| { | |||
| auto op = OprAttr::make("BatchNorm"); | |||
| @@ -210,8 +215,8 @@ TEST(TestImperative, BatchNormGrad) { | |||
| Param param; | |||
| param.fwd_mode = Param::FwdMode::TRAINING; | |||
| attr.param.write_pod(param); | |||
| OpDef::make_backward_graph(attr, {inp, stat, stat}, | |||
| {true, true ,true}, {false, false, true}); | |||
| OpDef::make_backward_graph(attr, {inp, stat, stat}, {true, true, true}, | |||
| {false, false, true}); | |||
| } | |||
| } | |||
| @@ -220,7 +225,8 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||
| LogicalTensorDesc desc = {TensorLayout(dtype::Float32()), cn}; | |||
| HostTensorGenerator<> gen; | |||
| auto op = std::shared_ptr<OpDef>(Elemwise::make(Elemwise::Mode::ADD)); | |||
| auto bg = OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true}); | |||
| auto bg = | |||
| OpDef::make_backward_graph(*op, {desc, desc}, {true, true}, {true}); | |||
| auto obg = OptimizedBackwardGraphResult(bg); | |||
| ASSERT_EQ(obg.save_for_backward.size(), 4); | |||
| ASSERT_FALSE(obg.save_for_backward[0]); | |||
| @@ -235,30 +241,30 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||
| auto dc_tn = Tensor::make(*dc_hv); | |||
| auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0]; | |||
| auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||
| auto grads = expand_grads(bg, bg.backward.apply( | |||
| backward_graph_inputs, | |||
| apply_shared_on_physical_tensor, | |||
| [&](auto&& x){ return x; } | |||
| )); | |||
| auto backward_graph_inputs = | |||
| prepare_backward_graph_inputs<SmallVector<TensorPtr>>( | |||
| bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||
| auto grads = | |||
| expand_grads(bg, bg.backward.apply(backward_graph_inputs, | |||
| apply_shared_on_physical_tensor, | |||
| [&](auto&& x) { return x; })); | |||
| auto precomp = obg.precomp.apply( | |||
| SmallVector<TensorPtr>{a_tn, b_tn, c_tn}, | |||
| apply_shared_on_physical_tensor, | |||
| [&](auto&& x){ return x; } | |||
| ); | |||
| auto precomp = obg.precomp.apply(SmallVector<TensorPtr>{a_tn, b_tn, c_tn}, | |||
| apply_shared_on_physical_tensor, | |||
| [&](auto&& x) { return x; }); | |||
| ASSERT_EQ(precomp.size(), 2); | |||
| ASSERT_EQ(precomp[0]->shape().ndim, 1); | |||
| ASSERT_LE(precomp[0]->shape()[0], 2); | |||
| ASSERT_EQ(precomp[1]->shape().ndim, 1); | |||
| ASSERT_LE(precomp[1]->shape()[0], 2); | |||
| auto backward_inputs = prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||
| auto grads2 = expand_grads(obg, obg.backward.apply( | |||
| backward_inputs, | |||
| apply_shared_on_physical_tensor, | |||
| [&](auto&& x){ return x; } | |||
| )); | |||
| auto backward_inputs = | |||
| prepare_optimized_backward_inputs<SmallVector<TensorPtr>>( | |||
| obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||
| auto grads2 = expand_grads( | |||
| obg, | |||
| obg.backward.apply(backward_inputs, apply_shared_on_physical_tensor, | |||
| [&](auto&& x) { return x; })); | |||
| ASSERT_EQ(grads2.size(), 2); | |||
| MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value()); | |||
| @@ -271,6 +271,7 @@ def BatchedSetMeshIndexing: FancyIndexingBase<"BatchedSetMeshIndexing">; | |||
| def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>; | |||
| def AssertEqual: MgbHashableOp<"AssertEqual", [AssertEqualParam]>; | |||
| def TQT: MgbHashableOp<"TQT", [TQTParam]>; | |||
| def LSQ: MgbHashableOp<"LSQ", [LSQParam]>; | |||
| def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> { | |||
| let extraArguments = (ins | |||
| MgbDTypeAttr:$dtype | |||
| @@ -324,5 +324,7 @@ decl_opr('FakeQuant', | |||
| decl_opr('TQT', | |||
| inputs=[Doc('src','input tensor'),Doc('scale','scale tensor')], | |||
| params='TQT') | |||
| decl_opr('LSQ', | |||
| inputs=[Doc('src','input tensor'),Doc('scale','scale tensor'), | |||
|         Doc('zero_point','zero point tensor'), | |||
|         Doc('grad_scale','grad scale tensor')], | |||
| params='LSQ') | |||
| # vim: ft=python | |||
| @@ -6,20 +6,22 @@ | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megbrain/opr/dnn/adaptive_pooling.h" | |||
| #include "megbrain/opr/dnn/batch_norm.h" | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/opr/dnn/correlation.h" | |||
| #include "megbrain/opr/dnn/fake_quant.h" | |||
| #include "megbrain/opr/dnn/images2neibs.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/dnn/adaptive_pooling.h" | |||
| #include "megbrain/opr/dnn/roi_pooling.h" | |||
| #include "megbrain/opr/dnn/roi_align.h" | |||
| #include "megbrain/opr/dnn/local.h" | |||
| #include "megbrain/opr/dnn/lrn.h" | |||
| #include "megbrain/opr/dnn/fake_quant.h" | |||
| #include "megbrain/opr/dnn/lsq.h" | |||
| #include "megbrain/opr/dnn/pooling.h" | |||
| #include "megbrain/opr/dnn/roi_align.h" | |||
| #include "megbrain/opr/dnn/roi_pooling.h" | |||
| #include "megbrain/opr/dnn/tqt.h" | |||
| #include "megbrain/serialization/sereg.h" | |||
| #include "megdnn/opr_param_defs.h" | |||
| @@ -183,7 +185,8 @@ struct ConvLoadDumpImpl { | |||
| static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { | |||
| auto&& opr = opr_.cast_final_safe<Opr>(); | |||
| ctx.write_param<ConvParam>(opr.param()); | |||
| ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy_transient()); | |||
| ctx.write_param<megdnn::param::ExecutionPolicy>( | |||
| opr.execution_policy_transient()); | |||
| } | |||
| static VarNode* make(const cg::VarNodeArray& inputs, const ConvParam& param, | |||
| @@ -251,6 +254,20 @@ struct OprMaker<opr::TQTBackward, 3> { | |||
| } | |||
| }; | |||
| template <> | |||
| struct OprMaker<opr::LSQBackward, 5> { | |||
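|     // LSQBackward::make returns a SymbolVarArray, so the generic OprMaker | |||
|     // cannot be used; rebuild the opr and return output(0)'s owner, the | |||
|     // same trick as the TQTBackward specialization above | |||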
| using Param = opr::LSQBackward::Param; | |||
| static cg::OperatorNodeBase* make(const Param& param, | |||
| const cg::VarNodeArray& i, | |||
| ComputingGraph& graph, | |||
| const OperatorNodeConfig& config) { | |||
| MGB_MARK_USED_VAR(graph); | |||
| return opr::LSQBackward::make(i[0], i[1], i[2], i[3], i[4], param, | |||
| config)[0] | |||
| .node() | |||
| ->owner_opr(); | |||
| } | |||
| }; | |||
| template <> | |||
| struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0> | |||
| : public PoolingLoadDumpImpl<opr::AdaptivePoolingBackward, | |||
| @@ -587,6 +604,8 @@ MGB_SEREG_OPR(FakeQuant, 3); | |||
| MGB_SEREG_OPR(FakeQuantBackward, 4); | |||
| MGB_SEREG_OPR(TQT, 2); | |||
| MGB_SEREG_OPR(TQTBackward, 3); | |||
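| // the second argument of MGB_SEREG_OPR is the input arity the serializer | |||
| // checks against: 4 inputs for LSQ forward, 5 for LSQBackward | |||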
| MGB_SEREG_OPR(LSQ, 4); | |||
| MGB_SEREG_OPR(LSQBackward, 5); | |||
| } // namespace opr | |||
| @@ -0,0 +1,90 @@ | |||
| /** | |||
| * \file src/opr/impl/dnn/lsq.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megbrain/opr/dnn/lsq.h" | |||
| #include "../internal/megdnn_opr_wrapper.inl" | |||
| #include "megbrain/graph/grad_impl.h" | |||
| #include "megbrain/opr/basic_arith_wrapper.h" | |||
| #include "megbrain/opr/internal/out_shape_by_sym_var.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/opr/utility.h" | |||
| using namespace mgb; | |||
| using namespace opr; | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSQForward); | |||
| MEGDNN_OPR_INIT4(LSQForward, "lsq_fwd"); | |||
| #if MGB_ENABLE_GRAD | |||
| MGB_IMPL_OPR_GRAD(LSQForward) { | |||
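|     // grad[0] is dL/dx; grad[1] is the per-element dL/ds, which must be | |||
|     // reduce-summed back to the scale tensor's shape because scale is | |||
|     // broadcast in the forward pass; the remaining inputs get no grad | |||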
| SymbolVarArray grad = | |||
| LSQBackward::make(out_grad[0], opr.input(0), opr.input(1), | |||
| opr.input(2), opr.input(3), opr.param()); | |||
| if (wrt_idx == 0) { | |||
| return grad[0].node(); | |||
| } else if (wrt_idx == 1) { | |||
| return reduce_sum(grad[1], GetVarShape::make(opr.input(wrt_idx))) | |||
| .node(); | |||
| } else { | |||
| return nullptr; | |||
| } | |||
| } | |||
| #endif | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(LSQBackward); | |||
| LSQBackward::LSQBackward(VarNode* y_grad, VarNode* x, VarNode* scale, | |||
| VarNode* zero_point, VarNode* grad_scale, | |||
| const Param& param, const OperatorNodeConfig& config) | |||
| : Super({x->owner_graph(), | |||
| config, | |||
| "lsq_bwd", | |||
| {y_grad, x, scale, zero_point, grad_scale}}, | |||
| 1, true) { | |||
| init_megdnn_opr(*this, param); | |||
| add_input({y_grad, x, scale, zero_point, grad_scale}); | |||
| } | |||
| SymbolVarArray LSQBackward::make(SymbolVar y_grad, SymbolVar x, SymbolVar scale, | |||
| SymbolVar zero_point, SymbolVar grad_scale, | |||
| const Param& param, | |||
| const OperatorNodeConfig& config) { | |||
| auto&& out = x.node()->owner_graph() | |||
| ->insert_opr(std::make_unique<LSQBackward>( | |||
| y_grad.node(), x.node(), scale.node(), | |||
| zero_point.node(), grad_scale.node(), param, | |||
| config)) | |||
| ->output(); | |||
| SymbolVarArray ret(out.size()); | |||
| for (size_t i = 0; i < ret.size(); ++i) { | |||
| ret[i] = out[i]; | |||
| } | |||
| return ret; | |||
| } | |||
| void LSQBackward::init_output_static_infer_desc() { | |||
| using namespace cg::static_infer; | |||
| auto&& mgr = owner_graph()->static_infer_manager(); | |||
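|     // both outputs take x's shape: output(0) is grad_x and output(1) is | |||
|     // the per-element scale gradient that MGB_IMPL_OPR_GRAD later | |||
|     // reduce-sums down to the scale tensor's shape | |||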
| mgr.register_shape_infer(output(0), | |||
| ShapeInferDesc::make_identity(input(1))); | |||
| mgr.register_shape_infer(output(1), | |||
| ShapeInferDesc::make_identity(input(1))); | |||
| this->init_output_static_infer_desc_workspace( | |||
| intl::AutoAddWorkspaceNeedLimitGetter<megdnn::LSQBackward>::val); | |||
| } | |||
| void LSQBackward::init_output_dtype() { | |||
| output(0)->dtype(input(1)->dtype()); | |||
| output(1)->dtype(input(2)->dtype()); | |||
| } | |||
| @@ -0,0 +1,50 @@ | |||
| /** | |||
| * \file src/opr/include/megbrain/opr/dnn/lsq.h | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #pragma once | |||
| #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
| #include "megdnn/oprs.h" | |||
| namespace mgb { | |||
| namespace opr { | |||
| MGB_DEFINE_OPR_CLASS(LSQForward, | |||
| intl::MegDNNOprWrapperFwd<megdnn::LSQForward>) // { | |||
| public: | |||
| LSQForward(VarNode* src, VarNode* scale, VarNode* zero_point, | |||
| VarNode* grad_scale, const Param& param, | |||
| const OperatorNodeConfig& config); | |||
| static SymbolVar make(SymbolVar src, SymbolVar scale, SymbolVar zero_point, | |||
| SymbolVar grad_scale, const Param& param = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| }; | |||
| using LSQ = LSQForward; | |||
| MGB_DEFINE_OPR_CLASS(LSQBackward, | |||
| intl::MegDNNOprWrapperBwd<megdnn::LSQBackward>) // { | |||
| public: | |||
| LSQBackward(VarNode* y_grad, VarNode* x, VarNode* scale, VarNode* zero_point, | |||
| VarNode* grad_scale, const Param& param, | |||
| const OperatorNodeConfig& config); | |||
| static SymbolVarArray make(SymbolVar y_grad, SymbolVar x, SymbolVar scale, | |||
| SymbolVar zero_point, SymbolVar grad_scale, | |||
| const Param& param = {}, | |||
| const OperatorNodeConfig& config = {}); | |||
| private: | |||
| void init_output_static_infer_desc() override; | |||
| void init_output_dtype() override; | |||
| }; | |||
| } // namespace opr | |||
| } // namespace mgb | |||
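| (A hypothetical graph-construction snippet against the API above; x, s, | |||
| zp and gs are SymbolVars assumed to live on the same graph, and the int8 | |||
| range below is just an example value for the param fields:) | |||
| opr::LSQ::Param param;  // carries the qmin/qmax quantization bounds | |||
| param.qmin = -128; | |||
| param.qmax = 127; | |||
| auto y = opr::LSQ::make(x, s, zp, gs, param);  // fake-quantized output | |||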
| @@ -0,0 +1,78 @@ | |||
| /** | |||
| * \file src/opr/test/dnn/lsq.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
| * implied. | |||
| */ | |||
| #include "megbrain/opr/dnn/lsq.h" | |||
| #include "megbrain/comp_node_env.h" | |||
| #include "megbrain/test/autocheck.h" | |||
| using namespace std; | |||
| using namespace mgb; | |||
| namespace { | |||
| void run() { | |||
| using Checker = AutoOprChecker<4, 1>; | |||
| auto make_graph = | |||
| [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { | |||
| auto o0 = opr::LSQForward::make(inputs[0], inputs[1], inputs[2], | |||
| inputs[3]); | |||
| return {o0}; | |||
| }; | |||
| auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { | |||
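|         // reference path: run the megdnn LSQForward kernel directly on | |||
|         // the default CPU comp node; the checker compares this against | |||
|         // the graph execution built by make_graph | |||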
| auto opr = MegDNNHandle::get( | |||
| CompNodeEnv::from_comp_node(CompNode::default_cpu())) | |||
| ->create_operator<megdnn::LSQForward>(); | |||
| dest[0].dtype(dtype::Float32()) | |||
| .comp_node(inp[0]->comp_node()) | |||
| .resize(inp[0]->shape()); | |||
| opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(), | |||
| inp[3]->as_megdnn(), dest[0].as_megdnn(), {}); | |||
| }; | |||
| auto gen = [&](HostTensorND& src) { | |||
| HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> | |||
| src_gen(10.f); | |||
| src = *src_gen(src.shape(), src.comp_node()); | |||
| }; | |||
| Checker::RunOptions opt; | |||
| opt.numdiff_max_err = 1e-5; | |||
| Checker checker{make_graph, fwd}; | |||
| checker.set_input_generator(0, gen) | |||
| .set_input_generator(1, gen) | |||
| .set_input_generator(2, gen) | |||
| .set_input_generator(3, gen) | |||
| .set_input_allow_grad(0, false) | |||
| .set_input_allow_grad(1, false) | |||
| .set_input_allow_grad(2, false) | |||
| .set_input_allow_grad(3, false) | |||
| .set_output_allow_grad(0, false); | |||
| checker.run({TensorShape{1, 2, 3, 4}, TensorShape{1}, TensorShape{1}, | |||
| TensorShape{1}}, | |||
| opt) | |||
| .run({TensorShape{2, 3, 8, 8}, TensorShape{1}, TensorShape{1}, | |||
| TensorShape{1}}, | |||
| opt) | |||
| .run({TensorShape{1, 3, 4, 4}, TensorShape{1}, TensorShape{1}, | |||
| TensorShape{1}}, | |||
| opt); | |||
| } | |||
| } // anonymous namespace | |||
| TEST(TestOprDNN, LSQForward) { | |||
| REQUIRE_GPU(1); | |||
| run(); | |||
| } | |||
| @@ -107,6 +107,7 @@ union OperatorParam { | |||
| param.FakeQuant = 73, | |||
| param.TQT = 74, | |||
| param.Correlation = 75, | |||
| param.LSQ = 76, | |||
| } | |||
| table Operator { | |||