@@ -2048,6 +2048,53 @@ protected:
            const TensorLayout& doup, const TensorLayout& mask,
            const TensorLayout& dinp, size_t workspace_in_bytes);
};

class SoftmaxBase : public OperatorBase {
    DEF_OPR_IMPL_CTOR(SoftmaxBase, OperatorBase);
    DEF_OPR_PARAM(Softmax);

protected:
    void deduce_layout_fwd(const TensorLayout& input, TensorLayout& output);
    void check_layout_fwd(const TensorLayout& input, const TensorLayout& output);
};

class SoftmaxForward : public SoftmaxBase {
    DEF_OPR_IMPL(SoftmaxForward, SoftmaxBase, 1, 1);

public:
    /**
     * \param[in] input input tensor
     * \param[out] output output tensor
     */
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_out output,
            _megdnn_workspace workspace) = 0;
    void deduce_layout(const TensorLayout& input, TensorLayout& output);
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& output) = 0;

protected:
    void check_exec(
            const TensorLayout& input, const TensorLayout& output,
            size_t workspace_in_bytes);
};
using Softmax = SoftmaxForward;

class SoftmaxBackward : public SoftmaxBase {
    DEF_OPR_IMPL(SoftmaxBackward, SoftmaxBase, 2, 1);

public:
    virtual void exec(
            _megdnn_tensor_in input, _megdnn_tensor_in diff, _megdnn_tensor_out grad_x,
            _megdnn_workspace workspace) = 0;
    virtual size_t get_workspace_in_bytes(
            const TensorLayout& input, const TensorLayout& diff,
            const TensorLayout& grad_x) = 0;

protected:
    void check_exec(
            const TensorLayout& input, const TensorLayout& diff,
            const TensorLayout& grad_x, size_t workspace_in_bytes);
};

class RNNCellForward : public OperatorBase {
    DEF_OPR_PARAM(RNNCell);
@@ -253,6 +253,10 @@ pdef('Axis').add_fields('int32', 'axis', 0)
    add_enum_alias('Format', 'Convolution')
)

(pdef('Softmax').
    add_fields('int32', 'axis', -1)
)

(pdef('AdaptivePooling', version=0, is_legacy=True).
    add_enum_alias('Mode', 'PoolingV0').
    add_enum_alias('Format', 'ConvolutionV0')
@@ -219,7 +219,9 @@ private:
    cb(RNN) \
    cb(RNNBackward) \
    cb(LSTM) \
    cb(LSTMBackward)
    cb(LSTMBackward) \
    cb(SoftmaxForward) \
    cb(SoftmaxBackward)
// clang-format on

/*!
@@ -145,6 +145,8 @@ DEF(RNNBackward, 10, true, true);
DEF(LSTMCellForward, 10, true, true);
DEF(LSTMForward, 8, true, true);
DEF(LSTMBackward, 13, true, true);
DEF(SoftmaxForward, 2, true, true);
DEF(SoftmaxBackward, 3, true, false);

} // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -0,0 +1,61 @@
/**
 * \file dnn/src/common/softmax.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megdnn/oprs.h"
#include "src/common/utils.h"

namespace megdnn {
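// Softmax keeps the input shape unchanged; `axis` may be negative and is then
// counted from the last dimension (e.g. -1 means the innermost axis).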
void SoftmaxBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst) {
    megdnn_assert(
            param().axis >= -static_cast<int32_t>(src.ndim) &&
                    param().axis < static_cast<int32_t>(src.ndim),
            "axis: %d ndim: %zu", param().axis, src.ndim);
    megdnn_assert_contiguous(src);
    dst = src;
    dst.dtype = src.dtype;
    dst.format = src.format;
    dst.init_contiguous_stride();
}

void SoftmaxBase::check_layout_fwd(const TensorLayout& src, const TensorLayout& dst) {
    TensorLayout dst_expected;
    megdnn_assert_eq_dtype(src, dst);
    deduce_layout_fwd(src, dst_expected);
    megdnn_assert_eq_layout(dst_expected, dst);
    megdnn_assert(src.dtype == dst.dtype);
}

void SoftmaxForward::deduce_layout(const TensorLayout& src, TensorLayout& dst) {
    deduce_layout_fwd(src, dst);
}

void SoftmaxForward::check_exec(
        const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
    check_layout_fwd(src, dst);
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

void SoftmaxBackward::check_exec(
        const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad,
        size_t workspace_in_bytes) {
    megdnn_assert_eq_layout(src, diff);
    megdnn_assert_eq_layout(src, grad);
    auto required_workspace_in_bytes = get_workspace_in_bytes(src, diff, grad);
    megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}

} // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -76,6 +76,7 @@
#include "src/cuda/separable_filter/opr_impl.h"
#include "src/cuda/sleep/opr_impl.h"
#include "src/cuda/sliding_window_transpose/opr_impl.h"
#include "src/cuda/softmax/opr_impl.h"
#include "src/cuda/split/opr_impl.h"
#include "src/cuda/svd/opr_impl.h"
#include "src/cuda/tensor_remap/opr_impl.h"
@@ -221,6 +222,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LayerNormBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(DropoutBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SoftmaxBackward);

template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() {
@@ -0,0 +1,174 @@
/**
 * \file dnn/src/cuda/softmax/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/cuda/softmax/opr_impl.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;

int CanonicalAxis(const int axis, const int rank) {
    if (axis < 0) {
        return axis + rank;
    }
    return axis;
}

int SizeToAxis(const int axis, const size_t* dims) {
    int size = 1;
    for (int i = 0; i < axis; i++) {
        size *= dims[i];
    }
    return size;
}

int SizeOutAxis(const int axis, const size_t* dims, const int ndim) {
    int size = 1;
    for (int i = axis + 1; i < ndim; i++) {
        size *= dims[i];
    }
    return size;
}
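// The layout is collapsed into (N, dim, D, 1): N is the product of dimensions
// before `axis`, dim is the softmax axis itself and D the product of trailing
// dimensions. When the softmax axis is the last one (D == 1),
// CUDNN_SOFTMAX_MODE_INSTANCE is used; otherwise CUDNN_SOFTMAX_MODE_CHANNEL
// normalizes over the "channel" slot of the 4-d view.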
std::vector<int> SoftmaxForwardImpl::init_mode(
        _megdnn_tensor_in src, cudnnSoftmaxMode_t& mode) const {
    auto dims = src.layout.shape;
    const int rank = src.layout.ndim;
    const int axis = CanonicalAxis(param().axis, rank);
    const int dim = dims[axis];
    const int N = SizeToAxis(axis, dims);
    const int D = SizeOutAxis(axis, dims, rank);

    mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE : CUDNN_SOFTMAX_MODE_CHANNEL;

    return {N, dim, D, 1};
}

int sc(const size_t x) {
    return static_cast<int>(x);
}

cudnnDataType_t to_cudnn_dtype(
        DType type, const param::Convolution::Format format = {}) {
    switch (type.enumv()) {
        case DTypeEnum::Float32:
            return CUDNN_DATA_FLOAT;
        case DTypeEnum::Float16:
            return CUDNN_DATA_HALF;
#if CUDNN_MAJOR >= 7
        case DTypeEnum::Int32:
        case DTypeEnum::QuantizedS32:
            return CUDNN_DATA_INT32;
#endif
#if CUDNN_MAJOR >= 6
        case DTypeEnum::QuantizedS8: {
            if (format == param::Convolution::Format::NCHW4)
                return CUDNN_DATA_INT8x4;
#if CUDNN_VERSION >= 7500
            else if (format == param::Convolution::Format::NCHW32)
                return CUDNN_DATA_INT8x32;
#endif
            else
                return CUDNN_DATA_INT8;
        }
        case DTypeEnum::Int8: {
            if (format == param::Convolution::Format::NCHW4)
                return CUDNN_DATA_INT8x4;
#if CUDNN_VERSION >= 7500
            else if (format == param::Convolution::Format::NCHW32)
                return CUDNN_DATA_INT8x32;
#endif
            else
                return CUDNN_DATA_INT8;
        }
#endif
        default:
#if CUDNN_MAJOR >= 6
            megdnn_throw("dtype must be float16/float32/int8/int32");
#else
            megdnn_throw("dtype must be float16/float32");
#endif
    }
}

void SoftmaxForwardImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    dt_float32 alpha = 1.0f, beta = 0.0f;
    TensorDesc src_desc, dst_desc;
    cudnnSoftmaxMode_t mode;
    std::vector<int> tensor_dims = init_mode(src, mode);

    const int dimA[] = {
            sc(tensor_dims[0]), sc(tensor_dims[1]), sc(tensor_dims[2]),
            sc(tensor_dims[3])};
    const int strideA[] = {
            sc(tensor_dims[1] * tensor_dims[2] * tensor_dims[3]),
            sc(tensor_dims[2] * tensor_dims[3]), sc(tensor_dims[3]), 1};

    cudnn_check(cudnnSetTensorNdDescriptor(
            src_desc.desc, to_cudnn_dtype(src.layout.dtype), 4, dimA, strideA));
    cudnn_check(cudnnSetTensorNdDescriptor(
            dst_desc.desc, to_cudnn_dtype(dst.layout.dtype), 4, dimA, strideA));

    cudnn_check(cudnnSoftmaxForward(
            cudnn_handle(this->handle()), CUDNN_SOFTMAX_ACCURATE, mode, &alpha,
            src_desc.desc, src.raw_ptr(), &beta, dst_desc.desc, dst.raw_ptr()));
}
//================================Softmax Backward============================

std::vector<int> SoftmaxBackwardImpl::init_mode(
        _megdnn_tensor_in src, cudnnSoftmaxMode_t& mode) const {
    auto dims = src.layout.shape;
    const int rank = src.layout.ndim;
    const int axis = CanonicalAxis(param().axis, rank);
    const int dim = dims[axis];
    const int N = SizeToAxis(axis, dims);
    const int D = SizeOutAxis(axis, dims, rank);

    mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE : CUDNN_SOFTMAX_MODE_CHANNEL;

    return {N, dim, D, 1};
}

void SoftmaxBackwardImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        _megdnn_workspace workspace) {
    {
        dt_float32 alpha = 1.0f, beta = 0.0f;
        TensorDesc src_desc, diff_desc, grad_desc;
        cudnnSoftmaxMode_t mode;
        std::vector<int> tensor_dims = init_mode(src, mode);

        const int dimA[] = {
                sc(tensor_dims[0]), sc(tensor_dims[1]), sc(tensor_dims[2]),
                sc(tensor_dims[3])};
        const int strideA[] = {
                sc(tensor_dims[1] * tensor_dims[2] * tensor_dims[3]),
                sc(tensor_dims[2] * tensor_dims[3]), sc(tensor_dims[3]), 1};

        cudnn_check(cudnnSetTensorNdDescriptor(
                src_desc.desc, to_cudnn_dtype(src.layout.dtype), 4, dimA, strideA));
        cudnn_check(cudnnSetTensorNdDescriptor(
                diff_desc.desc, to_cudnn_dtype(diff.layout.dtype), 4, dimA, strideA));
        cudnn_check(cudnnSetTensorNdDescriptor(
                grad_desc.desc, to_cudnn_dtype(grad.layout.dtype), 4, dimA, strideA));

        cudnn_check(cudnnSoftmaxBackward(
                cudnn_handle(this->handle()), CUDNN_SOFTMAX_ACCURATE, mode, &alpha,
                src_desc.desc, src.raw_ptr(), diff_desc.desc, diff.raw_ptr(), &beta,
                grad_desc.desc, grad.raw_ptr()));
    }
}

// vim: syntax=cpp.doxygen
@@ -0,0 +1,58 @@
/**
 * \file dnn/src/cuda/softmax/opr_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once
#include "megdnn/oprs.h"

#include "src/common/algo_base.h"
#include "src/common/metahelper.h"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/utils.h"

namespace megdnn {
namespace cuda {

class SoftmaxForwardImpl final : public SoftmaxForward {
public:
    using SoftmaxForward::SoftmaxForward;

    std::vector<int> init_mode(_megdnn_tensor_in src, cudnnSoftmaxMode_t& mode) const;

    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) override;

    size_t get_workspace_in_bytes(
            const TensorLayout&, /* src */
            const TensorLayout& /* dst */) override {
        return 0;
    }
};

class SoftmaxBackwardImpl final : public SoftmaxBackward {
public:
    using SoftmaxBackward::SoftmaxBackward;

    std::vector<int> init_mode(_megdnn_tensor_in src, cudnnSoftmaxMode_t& mode) const;

    size_t get_workspace_in_bytes(
            const TensorLayout& /* input */, const TensorLayout& /* diff */,
            const TensorLayout& /* grad_x */) override {
        return 0;
    }

    virtual void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
            _megdnn_workspace workspace) override;
};

} // namespace cuda
} // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -81,6 +81,7 @@
#include "src/naive/separable_filter/opr_impl.h"
#include "src/naive/sleep/opr_impl.h"
#include "src/naive/sliding_window_transpose/opr_impl.h"
#include "src/naive/softmax/opr_impl.h"
#include "src/naive/split/opr_impl.h"
#include "src/naive/svd/opr_impl.h"
#include "src/naive/tensor_remap/opr_impl.h"
@@ -0,0 +1,116 @@
/**
 * \file dnn/src/naive/softmax/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/naive/softmax/opr_impl.h"

#include <cstring>
#include "megdnn/dtype.h"
#include "megdnn/tensor_iter.h"
#include "src/common/elemwise_helper.cuh"
#include "src/common/opr_delegate.h"
#include "src/common/reduce_helper.h"
#include "src/common/utils.h"
#include "src/naive/elemwise/opr_impl.h"
#include "src/naive/handle.h"
#include "src/naive/lowbit_utils.h"

using namespace megdnn;

namespace {
template <typename T>
TensorND op_exec(_megdnn_tensor_in src, megdnn::dt_byte* workspace_ptr, const T& opr) {
    TensorLayout dst_layout;
    opr->deduce_layout(src.layout, dst_layout);
    TensorND dst{workspace_ptr, dst_layout};
    workspace_ptr += dst_layout.span().dist_byte();

    auto new_workspace = Workspace{
            workspace_ptr, opr->get_workspace_in_bytes(src.layout, dst_layout)};
    workspace_ptr += opr->get_workspace_in_bytes(src.layout, dst_layout);

    opr->exec(src, dst, new_workspace);
    return dst;
}
} // namespace

namespace megdnn {
namespace naive {

//===============================Softmax Forward============================
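// Numerically stable softmax along `axis`, composed from Reduce/Elemwise
// sub-operators:
//   dst = exp(src - max(src, axis)) / sum(exp(src - max(src, axis)), axis)
// The running max, the exp result and the reduced sum are placed in the
// caller-provided workspace.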
void SoftmaxForwardImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    auto axis = param().axis;
    if (axis < 0)
        axis += src.layout.ndim;
    check_exec(src.layout, dst.layout, workspace.size);

    auto workspace_ptr = workspace.raw_ptr;

    auto reduce_opr = handle()->create_operator<ReduceForward>();
    reduce_opr->param().axis = axis;
    reduce_opr->param().mode = Reduce::Mode::MAX;
    reduce_opr->param().data_type = param::Reduce::DataType::DEFAULT;
    TensorND max_tensor = op_exec(src, workspace_ptr, reduce_opr);

    auto elemwise_opr = handle()->create_operator<Elemwise>();
    elemwise_opr->param().mode = Elemwise::Mode::SUB;
    elemwise_opr->exec({src, max_tensor}, dst);

    elemwise_opr->param().mode = Elemwise::Mode::EXP;
    TensorLayout exp_layout;
    elemwise_opr->deduce_layout({src.layout}, exp_layout);
    TensorND exp_tensor{workspace_ptr, exp_layout};
    workspace_ptr += exp_layout.span().dist_byte();
    elemwise_opr->exec({dst}, exp_tensor);

    reduce_opr->param().mode = Reduce::Mode::SUM;
    TensorND down_tensor = op_exec(exp_tensor, workspace_ptr, reduce_opr);

    elemwise_opr->param().mode = Elemwise::Mode::TRUE_DIV;
    elemwise_opr->exec({exp_tensor, down_tensor}, dst);
}

//=============================Softmax Backward============================
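// Softmax gradient: with y = softmax(x) (passed in as `src`) and dy = `diff`,
//   grad = y * dy - y * sum(y * dy, axis)
// i.e. grad = y * (dy - sum(y * dy, axis)), computed below with two products,
// one reduction and a final subtraction.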
void SoftmaxBackwardImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        _megdnn_workspace workspace) {
    auto axis = param().axis;
    if (axis < 0)
        axis += src.layout.ndim;
    check_exec(src.layout, diff.layout, grad.layout, workspace.size);

    auto workspace_ptr = workspace.raw_ptr;

    TensorLayout mulres = src.layout;
    mulres.dtype = src.layout.dtype;
    mulres.format = src.layout.format;
    mulres.init_contiguous_stride();

    TensorND mul_tensor{workspace_ptr, mulres};
    workspace_ptr += mulres.span().dist_byte();
    TensorND mul_tensor2{workspace_ptr, mulres};
    workspace_ptr += mulres.span().dist_byte();

    auto elemwise_opr = handle()->create_operator<Elemwise>();
    elemwise_opr->param().mode = Elemwise::Mode::MUL;
    elemwise_opr->exec({src, diff}, mul_tensor);

    auto reduce_opr = handle()->create_operator<ReduceForward>();
    reduce_opr->param().axis = axis;
    reduce_opr->param().mode = Reduce::Mode::SUM;
    reduce_opr->param().data_type = param::Reduce::DataType::DEFAULT;
    TensorND sum_tensor = op_exec(mul_tensor, workspace_ptr, reduce_opr);

    elemwise_opr->exec({sum_tensor, src}, mul_tensor2);

    elemwise_opr->param().mode = Elemwise::Mode::SUB;
    elemwise_opr->exec({mul_tensor, mul_tensor2}, grad);
}

} // namespace naive
} // namespace megdnn
@@ -0,0 +1,45 @@
/**
 * \file dnn/src/naive/softmax/opr_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include "megdnn/oprs.h"

namespace megdnn {
namespace naive {

class SoftmaxForwardImpl final : public SoftmaxForward {
public:
    using SoftmaxForward::SoftmaxForward;
    void exec(
            _megdnn_tensor_in src, _megdnn_tensor_out dst,
            _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout&) override {
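        // Workspace backs the intermediate tensors built in exec(): the reduced
        // max, the exp result and the reduced sum are all allocated from it.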
        return src.span().dist_byte() * 2;
    }
};

class SoftmaxBackwardImpl final : public SoftmaxBackward {
public:
    using SoftmaxBackward::SoftmaxBackward;
    void exec(
            _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad_x,
            _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(
            const TensorLayout& src, const TensorLayout&,
            const TensorLayout&) override {
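        // Workspace backs the temporaries allocated in exec(): two full-size
        // elementwise products plus the reduced sum tensor.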
        return src.span().dist_byte() * 3;
    }
};

} // namespace naive
} // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -0,0 +1,41 @@
/**
 * \file dnn/test/common/softmax.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once
#include <cstddef>
#include <vector>
#include "megdnn/basic_types.h"
#include "megdnn/opr_param_defs.h"

namespace megdnn {
namespace test {
namespace softmax {

struct TestArg {
    param::Softmax param;
    TensorShape ishape;
    TestArg(param::Softmax param, TensorShape ishape) : param(param), ishape(ishape) {}
};

inline std::vector<TestArg> get_args() {
    std::vector<TestArg> args;
    using Param = param::Softmax;

    for (int32_t axis = 0; axis < 5; axis++) {
        args.emplace_back(Param{axis}, TensorShape{2, 23, 32, 30, 17});
    }
    return args;
}

} // namespace softmax
} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -0,0 +1,71 @@
/**
 * \file dnn/test/cuda/softmax.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "test/cuda/fixture.h"

#include "megdnn/tensor_iter.h"
#include "test/common/checker.h"
#include "test/common/softmax.h"

#include "src/common/utils.h"
#include "test/cuda/utils.h"

// to check cudnn version
#include <cudnn.h>
#include "test/cuda/benchmark.h"

namespace megdnn {
namespace test {

TEST_F(CUDA, SOFTMAX_FORWARD) {
    auto args = softmax::get_args();
    std::vector<DType> dtypes{dtype::Float16(), dtype::Float32()};

    for (auto dtype : dtypes)
        for (auto&& arg : args) {
            auto param = arg.param;
            auto src = arg.ishape;
            Checker<Softmax> checker(handle_cuda());

            if (dtype == dtype::BFloat16()) {
                checker.set_epsilon(2e-2);
            } else {
                checker.set_epsilon(1e-2);
            }
            checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec(
                    TensorShapeArray{src, {}});
        }
}

TEST_F(CUDA, SOFTMAX_BACKWARD) {
    auto args = softmax::get_args();
    for (auto&& arg : args) {
        Checker<SoftmaxBackward> checker(handle_cuda());
        TensorLayout ilayout = TensorLayout(arg.ishape, dtype::Float32());
        TensorLayout olayout;

        {
            auto opr = handle_cuda()->create_operator<SoftmaxForward>();
            opr->param() = arg.param;
            opr->deduce_layout(ilayout, olayout);
        }
        auto set_dtype = [&checker](DType dtype) {
            checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
        };

        set_dtype(dtype::Float32());
        checker.set_epsilon(1e-3).set_param(arg.param).exec(
                TensorShapeArray{ilayout, olayout, ilayout});
    }
}

} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen
@@ -0,0 +1,56 @@
/**
 * \file dnn/test/naive/softmax.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "test/naive/fixture.h"

#include "megdnn/oprs/nn.h"
#include "test/common/checker.h"

using namespace megdnn;
using namespace test;

TEST_F(NAIVE, SOFTMAX_FORWARD) {
    Checker<Softmax> checker(handle(), /* check_dispatch */ false);

    Softmax::Param param{0};

    TensorND input = TensorValue(
            {2, 2, 2, 2}, dtype::Float32(),
            {0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.});

    TensorND output = TensorValue(
            {2, 2, 2, 2}, dtype::Float32(),
            {0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.9997,
             0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997});

    checker.set_param(param).exect(Testcase{input, {}}, Testcase{{}, output});
}

TEST_F(NAIVE, SOFTMAX_BACKWARD) {
    Checker<SoftmaxBackward> checker(handle(), /* check_dispatch */ false);

    Softmax::Param param{0};

    TensorND input = TensorValue(
            {2, 2, 2, 2}, dtype::Float32(),
            {0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.0003, 0.9997,
             0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997, 0.9997});

    TensorND diff = TensorValue(
            {2, 2, 2, 2}, dtype::Float32(),
            {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.});

    TensorND output = TensorValue(
            {2, 2, 2, 2}, dtype::Float32(),
            {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.});

    checker.set_param(param).exect(Testcase{input, diff, {}}, Testcase{{}, {}, output});
}
@@ -1061,10 +1061,15 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
    """
    if axis is None:
        axis = _get_softmax_axis(len(inp.shape))
    offset = inp.max(axis=axis, keepdims=True).detach()
    cached = exp(inp - offset)
    down = sum(cached, axis=axis, keepdims=True)
    return cached / down
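    # A list of axes keeps the composite max/exp/sum formulation; a single int
    # axis is dispatched to the fused builtin Softmax operator.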
    if isinstance(axis, list):
        offset = inp.max(axis=axis, keepdims=True).detach()
        cached = exp(inp - offset)
        down = sum(cached, axis=axis, keepdims=True)
        return cached / down
    else:
        op = builtin.Softmax(axis=axis,)
        (output,) = apply(op, inp)
        return output


def layer_norm(
@@ -0,0 +1,52 @@
/**
 * \file imperative/src/impl/ops/softmax.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megbrain/opr/dnn/softmax.h"
#include "megbrain/imperative/ops/autogen.h"

#include "../dnn_op_helper.h"
#include "../op_trait.h"

namespace mgb {
namespace imperative {

namespace {
namespace softmax {

auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& softmax = static_cast<const Softmax&>(def);
    OperatorNodeConfig config{softmax.make_name()};
    return opr::Softmax::make(inputs[0], softmax.param(), config);
}

std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
    auto* node = &node_->cast_final_safe<opr::Softmax>();
    return Softmax::make(node->param());
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef&, const SmallVector<LogicalTensorDesc>& inputs) {
    SmallVector<LogicalTensorDesc> out_shapes(1);
    auto&& i0 = inputs[0];
    out_shapes[0] = {i0.layout, i0.comp_node};
    return {out_shapes, true};
}

OP_TRAIT_REG(Softmax, Softmax, opr::Softmax)
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .fallback();

} // namespace softmax
} // namespace

} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -354,6 +354,7 @@ def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>;
def AssertEqual: MgbHashableOp<"AssertEqual",[AssertEqualParam]>;
def TQT: MgbHashableOp<"TQT", [TQTParam]>;
def LSQ: MgbHashableOp<"LSQ", [LSQParam]>;
def Softmax: MgbHashableOp<"Softmax", [SoftmaxParam]>;

def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> {
    let extraArguments = (ins
        MgbDTypeAttr:$dtype
@@ -327,4 +327,7 @@ decl_opr('TQT',
decl_opr('LSQ',
         inputs=[Doc('src','input tensor'),Doc('scale','scale tensor'),Doc('zero_point','zero point tensor'),Doc('grad_scale','grad scale tensor')],
         params='LSQ')
decl_opr('Softmax',
         inputs=[Doc('src','input tensor')],
         params='Softmax')

# vim: ft=python
@@ -25,6 +25,7 @@
#include "megbrain/opr/dnn/roi_align.h"
#include "megbrain/opr/dnn/roi_pooling.h"
#include "megbrain/opr/dnn/sliding_window_transpose.h"
#include "megbrain/opr/dnn/softmax.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
@@ -324,6 +325,19 @@ struct OprMaker<opr::LSTMBackward, 9> {
    }
};

template <>
struct OprMaker<opr::SoftmaxBackward, 2> {
    using Param = opr::SoftmaxBackward::Param;
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return opr::SoftmaxBackward::make(i[0], i[1], param, config)
                .node()
                ->owner_opr();
    }
};

template <>
struct OprLoadDumpImpl<opr::AdaptivePoolingBackward, 0>
        : public GeneralOprLoadDumpImpl<
@@ -720,6 +734,8 @@ MGB_SEREG_OPR(RNNForward, 3);
MGB_SEREG_OPR(RNNBackward, 7);
MGB_SEREG_OPR(LSTMForward, 4);
MGB_SEREG_OPR(LSTMBackward, 9);
MGB_SEREG_OPR(Softmax, 1);
MGB_SEREG_OPR(SoftmaxBackward, 2);

} // namespace opr
} // namespace mgb
@@ -0,0 +1,124 @@
/**
 * \file src/opr/impl/dnn/softmax.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megbrain/opr/dnn/softmax.h"

#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megbrain/opr/utility.h"

#include "../internal/megdnn_opr_wrapper.inl"

using namespace mgb;
using namespace opr;

/* ==================== SoftmaxForward ==================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(SoftmaxForward);

SoftmaxForward::SoftmaxForward(
        VarNode* inp, const Param& param, const OperatorNodeConfig& config)
        : Super{inp->owner_graph(), config, "softmax", {inp}} {
    init_megdnn_opr(*this, param);
    add_input({inp});
    output(0)->dtype(inp->dtype());
}

SymbolVar SoftmaxForward::make(
        SymbolVar inp, const Param& param, const OperatorNodeConfig& config) {
    auto out = inp.node()
                       ->owner_graph()
                       ->insert_opr(std::make_unique<SoftmaxForward>(
                               inp.node(), param, config))
                       ->output();
    return out[0];
}

void SoftmaxForward::get_output_var_shape(
        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
    out_shape[0] = inp_shape[0];
}

size_t SoftmaxForward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    return megdnn_opr()->get_workspace_in_bytes(
            {input_shapes[0], input(0)->dtype(), input(0)->format()},
            {output_shapes[0], output(0)->dtype(), output(0)->format()});
}

void SoftmaxForward::scn_do_execute() {
    megdnn_opr()->exec(
            input(0)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
            intl::get_megdnn_workspace_from_var(output().back()));
}

#if MGB_ENABLE_GRAD
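// The gradient rule reuses the forward result: SoftmaxBackward receives the
// softmax output y (opr.output(0)) together with the incoming gradient,
// matching the y * (dy - sum(y * dy)) formulation used by the kernels.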
MGB_IMPL_OPR_GRAD(SoftmaxForward) {
    SymbolVar grad = SoftmaxBackward::make(opr.output(0), out_grad[0], opr.param());
    return grad.node();
}
#endif

/* ==================== SoftmaxBackward ==================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(SoftmaxBackward);

SoftmaxBackward::SoftmaxBackward(
        VarNode* src, VarNode* diff, const Param& param,
        const OperatorNodeConfig& config)
        : Super({src->owner_graph(), config, "Softmax_backward", {src, diff}}, 0,
                true) {
    init_megdnn_opr(*this, param);
    add_input({src, diff});
}

SymbolVar SoftmaxBackward::make(
        SymbolVar src, SymbolVar diff, const Param& param,
        const OperatorNodeConfig& config) {
    auto out = src.node()
                       ->owner_graph()
                       ->insert_opr(std::make_unique<SoftmaxBackward>(
                               src.node(), diff.node(), param, config))
                       ->output();
    return out[0];
}

void SoftmaxBackward::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();
    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(0)));
    this->init_output_static_infer_desc_workspace(false);
}

void SoftmaxBackward::init_output_dtype() {
    output(0)->dtype(input(0)->dtype());
}

size_t SoftmaxBackward::get_workspace_size_bytes(
        const TensorShapeArray& input_shapes,
        const TensorShapeArray& output_shapes) const {
    return megdnn_opr()->get_workspace_in_bytes(
            {input_shapes[0], input(0)->dtype(), input(0)->format()},
            {input_shapes[1], input(1)->dtype(), input(1)->format()},
            {output_shapes[0], output(0)->dtype(), output(0)->format()});
}

void SoftmaxBackward::scn_do_execute() {
    megdnn_opr()->exec(
            input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
            output(0)->dev_tensor().as_megdnn(),
            intl::get_megdnn_workspace_from_var(output().back()));
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -0,0 +1,64 @@
/**
 * \file src/opr/include/megbrain/opr/dnn/softmax.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#pragma once

#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megdnn/oprs/nn.h"

namespace mgb {
namespace opr {

MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
        SoftmaxForward, intl::MegDNNOprWrapperFwd<megdnn::SoftmaxForward>) // {
public:
    MGE_WIN_DECLSPEC_FUC SoftmaxForward(
            VarNode* src, const Param& param, const OperatorNodeConfig& config);

    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar src, const Param& param = {},
            const OperatorNodeConfig& config = {});

private:
    void get_output_var_shape(
            const TensorShapeArray& inp_shape,
            TensorShapeArray& out_shape) const override;
    size_t get_workspace_size_bytes(
            const TensorShapeArray& input_shapes,
            const TensorShapeArray& output_shapes) const override;
    void scn_do_execute() override;
};
using Softmax = SoftmaxForward;

MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
        SoftmaxBackward, intl::MegDNNOprWrapperBwd<megdnn::SoftmaxBackward>) // {
public:
    MGE_WIN_DECLSPEC_FUC SoftmaxBackward(
            VarNode* x, VarNode* y_grad, const Param& param,
            const OperatorNodeConfig& config);

    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar x, SymbolVar y_grad, const Param& param = {},
            const OperatorNodeConfig& config = {});

private:
    void init_output_static_infer_desc() override;
    void init_output_dtype() override;
    size_t get_workspace_size_bytes(
            const TensorShapeArray& input_shapes,
            const TensorShapeArray& output_shapes) const override;
    void scn_do_execute() override;
};

} // namespace opr
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -0,0 +1,65 @@
/**
 * \file src/opr/test/dnn/softmax.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "megbrain/opr/dnn/softmax.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/test/autocheck.h"

using namespace std;
using namespace mgb;

namespace {
using Param = opr::SoftmaxForward::Param;

void run(int32_t axis) {
    using Checker = AutoOprChecker<1, 1>;
    Param param{axis};

    auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
        auto o0 = opr::SoftmaxForward::make(inputs[0], param);
        return {o0};
    };

    auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
        auto opr =
                MegDNNHandle::get(CompNodeEnv::from_comp_node(CompNode::default_cpu()))
                        ->create_operator<megdnn::SoftmaxForward>();
        opr->param() = param;
        dest[0].dtype(dtype::Float32())
                .comp_node(inp[0]->comp_node())
                .resize(inp[0]->shape());
        size_t wk_size =
                opr->get_workspace_in_bytes(inp[0]->layout(), dest[0].layout());
        std::unique_ptr<dt_byte[]> wk_store{new dt_byte[wk_size]};
        opr->exec(inp[0]->as_megdnn(), dest[0].as_megdnn(), {wk_store.get(), wk_size});
    };

    auto gen = [&](HostTensorND& src) {
        HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> src_gen(10.f);
        src = *src_gen(src.shape(), src.comp_node());
    };

    Checker::RunOptions opt;
    opt.numdiff_max_err = 1e-4;
    Checker checker{make_graph, fwd};
    checker.set_input_generator(0, gen);

    checker.run({TensorShape{1, 2, 3, 4}}, opt)
            .run({TensorShape{2, 3, 8, 8}}, opt)
            .run({TensorShape{1, 3, 4, 4}}, opt);
}

} // anonymous namespace

TEST(TestOprDNN, SoftmaxForward) {
    REQUIRE_GPU(1);
    run(1);
}
@@ -121,6 +121,7 @@ union OperatorParam {
    param.RNNCell = 87,
    param.RNN = 88,
    param.LSTM = 89,
    param.Softmax = 90,
}

table Operator