/** * \file dnn/src/cuda/convolution/helper.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include "./opr_impl.h" #include "src/common/algo_chooser.h" #include "src/common/utils.h" #include "src/cuda/cudnn_wrapper.h" #include "src/cuda/handle.h" namespace megdnn { namespace cuda { namespace convolution { using CanonizedFilterMeta = ConvolutionForward::CanonizedFilterMeta; //! conv size descriptor in the forward view struct ForwardSizeArgs { HandleImpl* handle; const TensorLayout* src_layout; const TensorLayout* filter_layout; CanonizedFilterMeta filter_meta; const TensorLayout* dst_layout; }; //! whether cudnn is supported for a filter meta bool is_cudnn_supported(const ForwardSizeArgs& args); //! get workspace bundle for matmul algo SmallVector matmul_get_workspace_bundle(const ForwardSizeArgs& args); struct CUDNNForwardDescs { TensorDesc src_desc, dst_desc; FilterDesc filter_desc; ConvDesc conv_desc; void set( const TensorLayout& src, const CanonizedFilterMeta& filter, const TensorLayout& dst, const param::Convolution& param) { src_desc.set(src, param.format); filter_desc.set(filter); dst_desc.set(dst, param.format); conv_desc.set(src.dtype, param, filter.group); } }; struct CUDNNBwdDataDescs { TensorDesc diff_desc, grad_desc; FilterDesc filter_desc; ConvDesc conv_desc; void set( const CanonizedFilterMeta& filter, const TensorLayout& diff, const TensorLayout& grad, const param::Convolution& param) { filter_desc.set(filter); diff_desc.set(diff, param.format); grad_desc.set(grad, param.format); conv_desc.set(filter.dtype, param, filter.group); } }; struct CUDNNBwdFilterDescs { TensorDesc diff_desc, src_desc; FilterDesc grad_desc; ConvDesc conv_desc; void set( const TensorLayout& src, const TensorLayout& diff, const CanonizedFilterMeta& grad, const param::Convolution& param) { src_desc.set(src, param.format); diff_desc.set(diff, param.format); grad_desc.set(grad); conv_desc.set(src.dtype, param, grad.group); } }; /*! * \brief flip conv filter * * Flip conv filter pointed by \p raw_ptr, store result in workspace, and * change \p raw_ptr to workspace. */ void flip_filter( const ForwardSizeArgs& args, const Workspace& workspace, RefPtr& raw_ptr); } // namespace convolution } // namespace cuda } // namespace megdnn // vim: syntax=cpp.doxygen