@@ -54,9 +54,12 @@ if(MGE_WITH_CUDA)
   add_library(cutlass INTERFACE)
   target_include_directories(
     cutlass
-    INTERFACE $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/cutlass/include>)
+    INTERFACE
+      $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/cutlass/include>
+      $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/cutlass/tools/util/include>)
+  add_library(cudnn-frontend INTERFACE)
+  target_include_directories(
+    cudnn-frontend
+    INTERFACE
+      $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/cudnn-frontend/include>)
 endif()

 if(MGE_WITH_TEST)
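The new `cudnn-frontend` target is declared as an INTERFACE library carrying only include directories, which works because cudnn-frontend is a header-only C++ wrapper over the cuDNN v8 backend API. A minimal consumer-side smoke test under that assumption (the header name comes from the cudnn-frontend repository; `CUDNN_VERSION` encodes `MAJOR*1000 + MINOR*100 + PATCHLEVEL`, so the guard value 8004 used throughout this diff means cuDNN 8.0.4):

```cpp
// Compiles with nothing but the include paths the INTERFACE targets export
// (assumption: cudnn-frontend stays header-only, so no extra link library).
#include <cudnn.h>
#include <cudnn_frontend.h>  // from third_party/cudnn-frontend/include

int main() {
    // CUDNN_VERSION = MAJOR*1000 + MINOR*100 + PATCHLEVEL, so 8004 == 8.0.4,
    // matching the #if CUDNN_VERSION >= 8004 guards in this change.
    return CUDNN_VERSION >= 8004 ? 0 : 1;
}
```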
@@ -22,7 +22,16 @@ public:
     bool operator==(const KeyStorage& k) const { return k1 == k.k1 && k2 == k.k2; }
 };

-struct Key {
+struct Hash {
+    size_t operator()(const KeyStorage& k) const {
+        size_t h1 = k.k1;
+        size_t h2 = k.k2;
+        h1 ^= h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2);
+        return h1;
+    }
+};
+
+class Key {
     Handle* m_handle;
     uint32_t m_opr_type;
     const TensorLayout* m_inp_layouts_ptr;
@@ -62,14 +71,6 @@ public:
     MGE_WIN_DECLSPEC_FUC void clear();

 private:
-    struct Hash {
-        size_t operator()(const KeyStorage& k) const {
-            size_t h1 = k.k1;
-            size_t h2 = k.k2;
-            h1 ^= h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2);
-            return h1;
-        }
-    };
     std::unordered_map<KeyStorage, Result, Hash> m_heuristic_cache;
 #if __DEPLOY_ON_XP_SP2__
     size_t m_mtx;
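The two hunks above only hoist `Hash` out of the private section so it is visible before the cache map declaration; the hashing itself is the classic boost-style `hash_combine`, where 0x9e3779b9 is floor(2^32 / golden ratio) and the shifts keep high and low bits interacting. A self-contained sketch of the same technique on a generic two-field key (illustrative names, not MegDNN's):

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <utility>

// Boost-style hash_combine: mixes h2 into h1. The constant 0x9e3779b9 is
// floor(2^32 / phi); together with the shifts it diffuses entropy across the
// whole word, so nearby keys collide far less than with a plain XOR.
struct PairHash {
    size_t operator()(const std::pair<uint64_t, uint64_t>& k) const {
        size_t h1 = static_cast<size_t>(k.first);
        size_t h2 = static_cast<size_t>(k.second);
        h1 ^= h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2);
        return h1;
    }
};

int main() {
    std::unordered_map<std::pair<uint64_t, uint64_t>, std::string, PairHash> cache;
    cache[{0x1234, 0x5678}] = "cached heuristic result";
    return cache.count({0x1234, 0x5678}) == 1 ? 0 : 1;
}
```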
@@ -222,6 +222,8 @@ target_link_libraries(megdnn PUBLIC opr_param_defs)
 if(MGE_WITH_CUDA)
   target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:cutlass>)
   target_include_directories(megdnn PRIVATE ${CUDNN_INCLUDE_DIR})
+  target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:cudnn-frontend>)
 endif()

 if(MGE_WITH_ROCM)
@@ -14,6 +14,12 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
     non_cudnn_algos.push_back(&matmul8x8x32);
     non_cudnn_algos.push_back(&batched_matmul);
     non_cudnn_algos.push_back(&int1_simple);

+#if CUDNN_VERSION >= 8004
+    all_algos.push_back(&cudnn_conv_v8);
+    all_algos.push_back(&cudnn_conv_bias_activation_v8);
+#endif
+
     fill_cudnn_algos();
     for (auto&& algo : cudnn_conv_bias_activations) {
         all_algos.push_back(&algo);
@@ -169,6 +175,30 @@ std::string ConvBiasForwardImpl::AlgoBase::SizeArgs::to_string() const {
             nonlinear_mode_str.c_str());
 }

+param::Convolution ConvBiasForwardImpl::AlgoBase::get_param_convolution(
+        const SizeArgs& args) const {
+    param::Convolution::Mode mode;
+    param::Convolution::Sparse sparse = args.filter_meta.group > 1
+                                              ? param::Convolution::Sparse::GROUP
+                                              : param::Convolution::Sparse::DENSE;
+    if (args.filter_meta.should_flip) {
+        mode = param::Convolution::Mode::CONVOLUTION;
+    } else {
+        mode = param::Convolution::Mode::CROSS_CORRELATION;
+    }
+    return param::Convolution{
+            mode,
+            args.filter_meta.padding[0],
+            args.filter_meta.padding[1],
+            args.filter_meta.stride[0],
+            args.filter_meta.stride[1],
+            args.filter_meta.dilation[1],
+            args.filter_meta.dilation[0],
+            sparse,
+            args.filter_meta.format,
+            args.opr->param().compute_mode};
+}
+
 void ConvBiasForwardImpl::AlgoPack::fill_cudnn_algos() {
     for (auto&& algo : CudnnAlgoPack::conv_fwd_algos()) {
         cudnn_conv_bias_activations.push_back(algo.first);
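`should_flip` in the new helper distinguishes true convolution from cross-correlation: most deep-learning "convolutions" are actually cross-correlations, and CONVOLUTION mode just means the filter is flipped before sliding. A tiny 1-D illustration of the difference, independent of MegDNN:

```cpp
#include <cstdio>

int main() {
    const float src[5] = {1, 2, 3, 4, 5};
    const float w[3] = {1, 0, -1};  // deliberately asymmetric filter
    // Cross-correlation: dst[i] = sum_k src[i + k] * w[k]
    // Convolution:       dst[i] = sum_k src[i + k] * w[K - 1 - k]  (flipped)
    for (int i = 0; i < 3; ++i) {
        float xcorr = 0, conv = 0;
        for (int k = 0; k < 3; ++k) {
            xcorr += src[i + k] * w[k];
            conv += src[i + k] * w[2 - k];
        }
        // e.g. i=0: xcorr = -2, conv = +2 -- same data, opposite sign here
        std::printf("i=%d xcorr=%g conv=%g\n", i, xcorr, conv);
    }
    return 0;
}
```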
@@ -76,6 +76,8 @@ public:
         CUDA_IMPLICIT_BATCHED_GEMM_FMA_NCHW_F32,
         CUDA_IMPLICIT_BATCHED_GEMM_HMMA_NCHW_F16,
         CUDA_SIMPLE_INT1,
+        CUDA_CUDNN_CONV_V8,
+        CUDA_CUDNN_CONVBIAS_V8,
     };
     using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
@@ -157,12 +159,40 @@ public:
     }
     virtual bool is_cudnn() const { return false; }
+
+    param::Convolution get_param_convolution(const SizeArgs& args) const;
+};
+
+class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase : public AlgoBase {
+public:
+    AlgoCUDNNConvBiasActivationBase() = default;
+    virtual ~AlgoCUDNNConvBiasActivationBase() = default;
+
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    void exec(const ExecArgs& args) const override;
+    bool is_cudnn() const override { return true; }
+
+    size_t get_preprocess_workspace_in_bytes(const SizeArgs& args) const override;
+    SmallVector<TensorLayout> deduce_preprocessed_filter_layout(
+            const SizeArgs& args) const override;
+    void exec_preprocess(const ExecArgs& args) const override;
+
+protected:
+    virtual size_t cudnn_get_workspace_in_bytes(const SizeArgs& args) const = 0;
+    virtual void cudnn_execute(
+            const ExecArgs& args, const Workspace& workspace, float alpha,
+            float beta) const = 0;
+
+protected:
+    std::string m_name;
 };

-class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final : public AlgoBase {
+class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation final
+        : public AlgoCUDNNConvBiasActivationBase {
 public:
     AlgoCUDNNConvBiasActivation(cudnnConvolutionFwdAlgo_t cudnn_enum)
-            : m_cudnn_enum(cudnn_enum) {
+            : AlgoCUDNNConvBiasActivationBase(), m_cudnn_enum(cudnn_enum) {
         megdnn_assert(
                 CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) !=
                 CudnnAlgoPack::conv_fwd_algos().end());

@@ -171,9 +201,6 @@ public:
                 "CUDNN:ConvBiasActivation:" + m_attr.name, {});
     }

-    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
-    void exec(const ExecArgs& args) const override;
-    param::Convolution get_param_convolution(const SizeArgs& args) const;
     bool is_available(const SizeArgs&) const override;
     const char* name() const override { return m_name.c_str(); }

@@ -191,8 +218,6 @@ public:
     cudnnConvolutionFwdAlgo_t cudnn_enum() { return m_cudnn_enum; }

-    bool is_cudnn() const override { return true; }
-
     MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONVBIAS)

     std::string param() const override {

@@ -202,11 +227,46 @@ public:
     }

 private:
-    std::string m_name;
+    size_t cudnn_get_workspace_in_bytes(const SizeArgs& args) const override;
+    void cudnn_execute(
+            const ExecArgs& args, const Workspace& workspace, float alpha,
+            float beta) const override;
+
+private:
     cudnnConvolutionFwdAlgo_t m_cudnn_enum;
     CudnnAlgoPack::Attr m_attr;
 };

+#if CUDNN_VERSION >= 8004
+class ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationV8 final
+        : public AlgoCUDNNConvBiasActivationBase {
+public:
+    AlgoCUDNNConvBiasActivationV8() : AlgoCUDNNConvBiasActivationBase() {
+        m_name = ConvBiasForward::algo_name<DefaultParam>(
+                "CUDNN:ConvBiasActivationV8", {});
+    }
+    ~AlgoCUDNNConvBiasActivationV8() = default;
+
+    bool is_available(const SizeArgs& args) const override;
+    AlgoAttribute attribute() const override {
+        return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
+    }
+    const char* name() const override { return m_name.c_str(); }
+
+    MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONVBIAS_V8)
+
+    std::string param() const override { return ""; }
+
+private:
+    size_t cudnn_get_workspace_in_bytes(const SizeArgs& args) const override;
+    void cudnn_execute(
+            const ExecArgs& args, const Workspace& workspace, float alpha,
+            float beta) const override;
+};
+#endif
+
 class ConvBiasForwardImpl::AlgoChanwise final : public AlgoBase {
 public:
     bool is_available(const SizeArgs& args) const override;
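The restructuring above is the template-method (non-virtual interface) pattern: the base class owns the shared pipeline, and each concrete algorithm supplies only the backend call through the protected `cudnn_get_workspace_in_bytes` / `cudnn_execute` hooks. A stripped-down sketch of that shape, with hypothetical names rather than the MegDNN classes:

```cpp
#include <cstdio>

// Template-method sketch: exec() fixes the skeleton; subclasses override only
// the backend-specific step, mirroring AlgoCUDNNConvBiasActivationBase::exec()
// wrapping the pure-virtual cudnn_execute().
class ConvAlgoBase {
public:
    virtual ~ConvAlgoBase() = default;
    void exec() const {
        prepare();       // shared: convert bias, carve workspace, ...
        backend_call();  // per-algorithm: legacy v7 call vs. v8 plan
        postprocess();   // shared: apply a nonlinearity the backend lacks
    }

protected:
    virtual void backend_call() const = 0;  // the only hook subclasses fill in

private:
    void prepare() const { std::puts("prepare workspace + bias"); }
    void postprocess() const { std::puts("apply nonlinearity"); }
};

class LegacyAlgo final : public ConvAlgoBase {
    void backend_call() const override { std::puts("cudnnConvolutionForward"); }
};

class PlanBasedAlgo final : public ConvAlgoBase {
    void backend_call() const override { std::puts("execute v8 plan"); }
};

int main() {
    LegacyAlgo a;
    PlanBasedAlgo b;
    a.exec();  // same pre/post steps...
    b.exec();  // ...different backend call
    return 0;
}
```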
@@ -284,9 +344,34 @@ private:
     mutable std::string m_name;
 };

-class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoBase {
+class ConvBiasForwardImpl::AlgoCUDNNConvBase : public AlgoBase {
+public:
+    AlgoCUDNNConvBase() = default;
+    virtual ~AlgoCUDNNConvBase() = default;
+
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override {
+        return get_workspace_bundle(nullptr, args).total_size_in_bytes();
+    }
+    void exec(const ExecArgs& args) const override;
+    bool is_cudnn() const override { return true; }
+
+protected:
+    virtual size_t cudnn_get_workspace_in_bytes(const SizeArgs& args) const = 0;
+    virtual void cudnn_execute(
+            const ExecArgs& args, const Workspace& workspace) const = 0;
+
+private:
+    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
+
+protected:
+    std::string m_name;
+};
+
+class ConvBiasForwardImpl::AlgoCUDNNConv final : public AlgoCUDNNConvBase {
 public:
-    AlgoCUDNNConv(cudnnConvolutionFwdAlgo_t cudnn_enum) : m_cudnn_enum(cudnn_enum) {
+    AlgoCUDNNConv(cudnnConvolutionFwdAlgo_t cudnn_enum)
+            : AlgoCUDNNConvBase(), m_cudnn_enum(cudnn_enum) {
         megdnn_assert(
                 CudnnAlgoPack::conv_fwd_algos().find(cudnn_enum) !=
                 CudnnAlgoPack::conv_fwd_algos().end());

@@ -296,8 +381,6 @@ public:
     }

     bool is_available(const SizeArgs& args) const override;
-    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
-    void exec(const ExecArgs& args) const override;

     AlgoAttribute attribute() const override {
         auto ret = static_cast<AlgoAttribute>(0);

@@ -314,8 +397,6 @@ public:
     cudnnConvolutionFwdAlgo_t cudnn_enum() const { return m_cudnn_enum; }

-    bool is_cudnn() const override { return true; }
-
     MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONV)

     std::string param() const override {

@@ -325,12 +406,38 @@ public:
     }

 private:
-    std::string m_name;
+    size_t cudnn_get_workspace_in_bytes(const SizeArgs& args) const override;
+    void cudnn_execute(const ExecArgs& args, const Workspace& workspace) const override;
+
+private:
     cudnnConvolutionFwdAlgo_t m_cudnn_enum;
     CudnnAlgoPack::Attr m_attr;
-
-    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
 };

+#if CUDNN_VERSION >= 8004
+class ConvBiasForwardImpl::AlgoCUDNNConvV8 final : public AlgoCUDNNConvBase {
+public:
+    AlgoCUDNNConvV8() : AlgoCUDNNConvBase() {
+        m_name = ConvBiasForward::algo_name<DefaultParam>("CUDNN:ConvolutionV8", {});
+    }
+
+    bool is_available(const SizeArgs& args) const override;
+    AlgoAttribute attribute() const override {
+        return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
+    }
+    const char* name() const override { return m_name.c_str(); }
+
+    MEGDNN_DECL_ALGO_TYPE(CUDA_CUDNN_CONV_V8)
+
+    std::string param() const override { return ""; }
+
+private:
+    size_t cudnn_get_workspace_in_bytes(const SizeArgs& args) const override;
+    void cudnn_execute(const ExecArgs& args, const Workspace& workspace) const override;
+};
+#endif
+
 //! compute small matmul in the kernel
 class ConvBiasForwardImpl::AlgoInplaceMatmul final : public AlgoBase {
@@ -1140,6 +1247,10 @@ public:
     AlgoGroupConvGeneral group;
     AlgoBFloat16 bfloat16;
     AlgoSimpleInt1 int1_simple;
+#if CUDNN_VERSION >= 8004
+    AlgoCUDNNConvV8 cudnn_conv_v8;
+    AlgoCUDNNConvBiasActivationV8 cudnn_conv_bias_activation_v8;
+#endif

     AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo);
@@ -56,99 +56,33 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(const SizeArgs& args) cons
     return status == CUDNN_STATUS_SUCCESS;
 }

-WorkspaceBundle ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_bundle(
-        void* ptr, const SizeArgs& args) const {
-    auto dst_layout = *args.dst_layout;
-    SmallVector<size_t> sizes;
-    if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
-        dst_layout.dtype = DType();
-        args.opr->check_or_deduce_dtype_fwd(
-                args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype);
-        sizes.push_back(dst_layout.span().dist_byte());
-    }
-
-    if (args.z_layout->ndim > 0 &&
-        args.z_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
-        auto z_layout = *args.z_layout;
-        z_layout.dtype = DType();
-        args.opr->check_or_deduce_dtype_fwd(
-                args.src_layout->dtype, args.filter_layout->dtype, z_layout.dtype);
-        sizes.push_back(z_layout.span().dist_byte());
-    }
-
-    SizeArgs conv_args = args;
-    conv_args.dst_layout = &dst_layout;
-
+size_t ConvBiasForwardImpl::AlgoCUDNNConv::cudnn_get_workspace_in_bytes(
+        const SizeArgs& args) const {
     CUDNNForwardDescs D;
-    conv_args.init_conv_desc(D);
+    args.init_conv_desc(D);
     size_t conv_workspace_size;
-    auto status = cudnnGetConvolutionForwardWorkspaceSize(
-            conv_args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
-            D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, &conv_workspace_size);
-    megdnn_assert(
-            status == CUDNN_STATUS_SUCCESS,
-            "conv fwd get workspace failed: %s; info: %s", cudnnGetErrorString(status),
-            args.to_string().c_str());
-    sizes.insert(sizes.begin(), conv_workspace_size);
-    return {ptr, std::move(sizes)};
-}
-
-size_t ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_in_bytes(
-        const SizeArgs& args) const {
-    return get_workspace_bundle(nullptr, args).total_size_in_bytes();
+    cudnn_check(cudnnGetConvolutionForwardWorkspaceSize(
+            args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
+            D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum,
+            &conv_workspace_size));
+    return conv_workspace_size;
 }

-void ConvBiasForwardImpl::AlgoCUDNNConv::exec(const ExecArgs& args) const {
-    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
-    TensorND conv_dst_tensor = *args.dst_tensor;
-    if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
-        conv_dst_tensor = TensorND{bundle.get(1), args.dst_tensor->layout};
-        conv_dst_tensor.layout.dtype = DType();
-        args.opr->check_or_deduce_dtype_fwd(
-                args.src_layout->dtype, args.filter_layout->dtype,
-                conv_dst_tensor.layout.dtype);
-    }
-
-    ExecArgs conv_args = args;
-    conv_args.dst_tensor = &conv_dst_tensor;
-    conv_args.dst_layout = &conv_dst_tensor.layout;
-
-    {
-        CUDNNForwardDescs D;
-        conv_args.init_conv_desc(D);
-        auto conv_workspace = bundle.get_workspace(0);
-        float alpha = 1.0f, beta = 0.0f;
-        auto status = cudnnConvolutionForward(
-                conv_args.handle->cudnn_handle(), &alpha, D.src_desc.desc,
-                conv_args.src_tensor->raw_ptr(), D.filter_desc.desc,
-                conv_args.filter_tensor->raw_ptr(), D.conv_desc.conv_desc, m_cudnn_enum,
-                conv_workspace.raw_ptr, conv_workspace.size, &beta, D.dst_desc.desc,
-                conv_args.dst_tensor->raw_ptr());
-        megdnn_assert(
-                status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s",
-                cudnnGetErrorString(status), conv_args.to_string().c_str());
-    }
-
-    if (args.z_layout->ndim > 0) {
-        auto z_tensor = *args.z_tensor;
-        if (args.z_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
-            z_tensor = TensorND{bundle.get(2), args.z_tensor->layout};
-            z_tensor.layout.dtype = DType();
-            args.opr->check_or_deduce_dtype_fwd(
-                    args.src_layout->dtype, args.filter_layout->dtype,
-                    z_tensor.layout.dtype);
-            auto typecvt = args.handle->create_operator<TypeCvt>();
-            typecvt->exec(*args.z_tensor, z_tensor);
-        }
-        auto add = args.handle->create_operator<ElemwiseForward>();
-        add->param().mode = Elemwise::Param::Mode::ADD;
-        add->exec({conv_dst_tensor, z_tensor}, conv_dst_tensor);
-    }
-    handle_bias_and_nonlinear(
-            args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor,
-            args.bias_tensor);
+void ConvBiasForwardImpl::AlgoCUDNNConv::cudnn_execute(
+        const ExecArgs& args, const Workspace& workspace) const {
+    CUDNNForwardDescs D;
+    args.init_conv_desc(D);
+    float alpha = 1.0f, beta = 0.0f;
+    auto status = cudnnConvolutionForward(
+            args.handle->cudnn_handle(), &alpha, D.src_desc.desc,
+            args.src_tensor->raw_ptr(), D.filter_desc.desc,
+            args.filter_tensor->raw_ptr(), D.conv_desc.conv_desc, m_cudnn_enum,
+            workspace.raw_ptr, workspace.size, &beta, D.dst_desc.desc,
+            args.dst_tensor->raw_ptr());
+    megdnn_assert(
+            status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s",
+            cudnnGetErrorString(status), args.to_string().c_str());
 }

 // vim: syntax=cpp.doxygen
@@ -0,0 +1,87 @@
+/**
+ * \file dnn/src/cuda/conv_bias/cudnn_conv_base.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "src/common/conv_bias.h"
+#include "src/cuda/conv_bias/algo.h"
+#include "src/cuda/utils.h"
+
+using namespace megdnn;
+using namespace cuda;
+using namespace conv_bias;
+
+WorkspaceBundle ConvBiasForwardImpl::AlgoCUDNNConvBase::get_workspace_bundle(
+        void* ptr, const SizeArgs& args) const {
+    auto dst_layout = *args.dst_layout;
+    SmallVector<size_t> sizes;
+    if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
+        dst_layout.dtype = DType();
+        args.opr->check_or_deduce_dtype_fwd(
+                args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype);
+        sizes.push_back(dst_layout.span().dist_byte());
+    }
+
+    if (args.z_layout->ndim > 0 &&
+        args.z_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
+        auto z_layout = *args.z_layout;
+        z_layout.dtype = DType();
+        args.opr->check_or_deduce_dtype_fwd(
+                args.src_layout->dtype, args.filter_layout->dtype, z_layout.dtype);
+        sizes.push_back(z_layout.span().dist_byte());
+    }
+
+    SizeArgs conv_args = args;
+    conv_args.dst_layout = &dst_layout;
+    size_t conv_workspace_size = cudnn_get_workspace_in_bytes(conv_args);
+    sizes.insert(sizes.begin(), conv_workspace_size);
+    return {ptr, std::move(sizes)};
+}
+
+void ConvBiasForwardImpl::AlgoCUDNNConvBase::exec(const ExecArgs& args) const {
+    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
+    TensorND conv_dst_tensor = *args.dst_tensor;
+    if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
+        conv_dst_tensor = TensorND{bundle.get(1), args.dst_tensor->layout};
+        conv_dst_tensor.layout.dtype = DType();
+        args.opr->check_or_deduce_dtype_fwd(
+                args.src_layout->dtype, args.filter_layout->dtype,
+                conv_dst_tensor.layout.dtype);
+    }
+
+    ExecArgs conv_args = args;
+    conv_args.dst_tensor = &conv_dst_tensor;
+    conv_args.dst_layout = &conv_dst_tensor.layout;
+
+    cudnn_execute(conv_args, bundle.get_workspace(0));
+
+    if (args.z_layout->ndim > 0) {
+        auto z_tensor = *args.z_tensor;
+        if (args.z_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
+            z_tensor = TensorND{bundle.get(2), args.z_tensor->layout};
+            z_tensor.layout.dtype = DType();
+            args.opr->check_or_deduce_dtype_fwd(
+                    args.src_layout->dtype, args.filter_layout->dtype,
+                    z_tensor.layout.dtype);
+            auto typecvt = args.handle->create_operator<TypeCvt>();
+            typecvt->exec(*args.z_tensor, z_tensor);
+        }
+        auto add = args.handle->create_operator<ElemwiseForward>();
+        add->param().mode = Elemwise::Param::Mode::ADD;
+        add->exec({conv_dst_tensor, z_tensor}, conv_dst_tensor);
+    }
+    handle_bias_and_nonlinear(
+            args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor,
+            args.bias_tensor);
+}
+
+// vim: syntax=cpp.doxygen
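`get_workspace_bundle` packs up to three variable-size buffers into the single workspace allocation the caller provides: slot 0 for cuDNN's scratch, plus optional slots for a dtype-converted dst and z. A minimal sketch of that carving technique under stated assumptions (hypothetical `Bundle`, not MegDNN's `WorkspaceBundle`, and ignoring the alignment padding the real class applies):

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Carve one raw allocation into consecutive sub-buffers. MegDNN's real
// WorkspaceBundle additionally aligns each chunk; this sketch does not.
class Bundle {
    void* m_ptr;
    std::vector<size_t> m_sizes;

public:
    Bundle(void* ptr, std::vector<size_t> sizes)
            : m_ptr(ptr), m_sizes(std::move(sizes)) {}

    size_t total_size_in_bytes() const {
        return std::accumulate(m_sizes.begin(), m_sizes.end(), size_t(0));
    }

    void* get(size_t i) const {  // start of the i-th sub-buffer
        auto* p = static_cast<uint8_t*>(m_ptr);
        for (size_t j = 0; j < i; ++j)
            p += m_sizes[j];
        return p;
    }
};

int main() {
    // Pattern from the diff: first query sizes with a null base pointer...
    Bundle query{nullptr, {4096, 256, 256}};
    std::vector<uint8_t> storage(query.total_size_in_bytes());
    // ...then rebuild over the real allocation; slot 0 goes to cuDNN,
    // slots 1/2 hold the converted dst/z tensors.
    Bundle bundle{storage.data(), {4096, 256, 256}};
    return bundle.get(1) == storage.data() + 4096 ? 0 : 1;
}
```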
@@ -124,10 +124,10 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
             // forbids sigmoid for quantized
             if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED)
                 return false;
-            MEGDNN_FALLTHRU  // XXX: why?
-            case param::ConvBias::NonlineMode::IDENTITY
-                    : if (args.src_layout->dtype.category() ==
-                              DTypeCategory::QUANTIZED) break;
+            MEGDNN_FALLTHRU;  // XXX: why?
+        case param::ConvBias::NonlineMode::IDENTITY:
+            if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED)
+                break;
             if (m_cudnn_enum != CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) {
                 // cudnn requires the algo to be
                 // CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
@@ -149,7 +149,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
     return status == CUDNN_STATUS_SUCCESS;
 }

-size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::get_workspace_in_bytes(
+size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::cudnn_get_workspace_in_bytes(
         const SizeArgs& args) const {
     CUDNNForwardDescs D;
@@ -162,85 +162,18 @@ size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::get_workspace_in_bytes(
             status == CUDNN_STATUS_SUCCESS,
             "conv fwd get workspace failed: %s; info: %s", cudnnGetErrorString(status),
             args.to_string().c_str());
-    if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
-        args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
-        // cudnn requires the bias to be float when executing CONFIG_INT;
-        // convert bias to float if it is not float in the first place
-        workspace_size += sizeof(float) * args.bias_layout->span().dist_elem();
-    }
     return workspace_size;
 }

-void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec(
-        const ExecArgs& args) const {
+void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::cudnn_execute(
+        const ExecArgs& args, const Workspace& workspace, float alpha,
+        float beta) const {
 #if CUDNN_MAJOR < 7
     megdnn_throw("ConvBias requires cudnn 7.0 or higher");
 #else
     megdnn_assert(cudnnGetVersion() >= 7401);
     CUDNNForwardDescs D;
     args.init_conv_bias_desc(D);
-    float alpha = 1.0f, beta = 0.0f;
-    if (args.z_layout->ndim > 0)
-        beta = 1.0f;
-
-    auto get_scale = [](const DType& dtype) -> float {
-        megdnn_assert(dtype.category() == DTypeCategory::QUANTIZED);
-        switch (dtype.enumv()) {
-#define cb(_dt)                  \
-    case DTypeTrait<_dt>::enumv: \
-        return dtype.param<_dt>().scale;
-            MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
-#undef cb
-            default:
-                megdnn_assert_internal(0);
-        }
-    };
-
-    auto src_dtype = args.src_layout->dtype, filter_dtype = args.filter_layout->dtype,
-         dst_dtype = args.dst_layout->dtype;
-    megdnn_assert(
-            (src_dtype.category() == dst_dtype.category()) ||
-            (src_dtype.enumv() == DTypeEnum::QuantizedS8 &&
-             dst_dtype.enumv() == DTypeEnum::Float32));
-    megdnn_assert(src_dtype.category() == filter_dtype.category());
-
-    if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED) {
-        auto expected_bias_scale = get_scale(args.src_layout->dtype) *
-                                   get_scale(args.filter_layout->dtype);
-        alpha = expected_bias_scale;
-        if (args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED)
-            alpha /= get_scale(args.dst_layout->dtype);
-        if (args.z_layout->ndim > 0 &&
-            args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) {
-            beta = get_scale(args.z_layout->dtype) / get_scale(args.dst_layout->dtype);
-        }
-        if (args.bias_layout->dtype.category() == DTypeCategory::QUANTIZED) {
-            megdnn_assert(
-                    fabs(expected_bias_scale - get_scale(args.bias_layout->dtype)) <
-                    1e-4);
-        }
-    }
-
-    auto workspace_ptr = args.workspace.raw_ptr;
-    auto workspace_size = args.workspace.size;
-    auto bias_ptr = args.bias_tensor->raw_ptr();
-    if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
-        args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
-        auto cvt = args.handle->create_operator<TypeCvt>();
-        auto float_bias_layout = *args.bias_layout;
-        auto converted_bias_layout = *args.bias_layout;
-        converted_bias_layout.dtype = dtype::QuantizedS32(alpha);
-        float_bias_layout.dtype = dtype::Float32();
-        auto bias_size_in_bytes = float_bias_layout.span().dist_byte();
-        megdnn_assert(args.workspace.size >= bias_size_in_bytes);
-        cvt->exec(
-                {args.bias_tensor->raw_ptr(), converted_bias_layout},
-                TensorND{workspace_ptr, float_bias_layout});
-
-        bias_ptr = workspace_ptr;
-        workspace_ptr += bias_size_in_bytes;
-        workspace_size -= bias_size_in_bytes;
-    }

     cudnnStatus_t status;
     if (args.z_layout->ndim == 0) {

@@ -248,55 +181,23 @@ void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::exec(
                 args.handle->cudnn_handle(), &alpha, D.src_desc.desc,
                 args.src_tensor->raw_ptr(), D.filter_desc.desc,
                 args.filter_tensor->raw_ptr(), D.conv_desc.conv_desc, m_cudnn_enum,
-                workspace_ptr, workspace_size, &beta, D.dst_desc.desc,
-                args.dst_tensor->raw_ptr(), D.bias_desc.desc, bias_ptr,
-                D.conv_desc.act_desc, D.dst_desc.desc, args.dst_tensor->raw_ptr());
+                workspace.raw_ptr, workspace.size, &beta, D.dst_desc.desc,
+                args.dst_tensor->raw_ptr(), D.bias_desc.desc,
+                args.bias_tensor->raw_ptr(), D.conv_desc.act_desc, D.dst_desc.desc,
+                args.dst_tensor->raw_ptr());
     } else {
         status = cudnnConvolutionBiasActivationForward(
                 args.handle->cudnn_handle(), &alpha, D.src_desc.desc,
                 args.src_tensor->raw_ptr(), D.filter_desc.desc,
                 args.filter_tensor->raw_ptr(), D.conv_desc.conv_desc, m_cudnn_enum,
-                workspace_ptr, workspace_size, &beta, D.z_desc.desc,
-                args.z_tensor->raw_ptr(), D.bias_desc.desc, bias_ptr,
+                workspace.raw_ptr, workspace.size, &beta, D.z_desc.desc,
+                args.z_tensor->raw_ptr(), D.bias_desc.desc, args.bias_tensor->raw_ptr(),
                 D.conv_desc.act_desc, D.dst_desc.desc, args.dst_tensor->raw_ptr());
     }
     megdnn_assert(
             status == CUDNN_STATUS_SUCCESS, "conv fwd failed: %s; info: %s, algo %s",
             cudnnGetErrorString(status), args.to_string().c_str(), name());
-
-    // nonlinear mode
-    switch (args.nonlinear_mode) {
-        case param::ConvBias::NonlineMode::RELU:
-            break;
-        case param::ConvBias::NonlineMode::SIGMOID: {
-            megdnn_assert(
-                    args.dst_layout->dtype.category() != DTypeCategory::QUANTIZED);
-            auto&& elem_opr = args.handle->create_operator<ElemwiseForward>();
-            elem_opr->param().mode = Elemwise::Param::Mode::SIGMOID;
-            elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
-            break;
-        }
-        case param::ConvBias::NonlineMode::IDENTITY:
-            break;
-        case param::ConvBias::NonlineMode::H_SWISH: {
-            megdnn_assert(
-                    args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED ||
-                    (args.dst_layout->dtype.category() == DTypeCategory::FLOAT &&
-                     args.opr->param().format == param::ConvBias::Format::NCHW4_NCHW));
-            if (args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED) {
-                auto&& elem_opr = args.handle->create_operator<ElemwiseMultiType>();
-                elem_opr->param().mode = ElemwiseMultiType::Param::Mode::QH_SWISH;
-                elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
-            } else {
-                auto&& elem_opr = args.handle->create_operator<ElemwiseForward>();
-                elem_opr->param().mode = ElemwiseForward::Param::Mode::H_SWISH;
-                elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
-            }
-            break;
-        }
-        default:
-            megdnn_throw("unsupported NonlineMode");
-    }
 #endif
 }
@@ -0,0 +1,210 @@
+/**
+ * \file dnn/src/cuda/conv_bias/cudnn_conv_bias_activation_base.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "megdnn/oprs/general.h"
+
+#include "./algo.h"
+
+#include "src/common/conv_bias.h"
+#include "src/cuda/conv_bias/helper.h"
+#include "src/cuda/cudnn_wrapper.h"
+#include "src/cuda/utils.h"
+
+using namespace megdnn;
+using namespace cuda;
+using namespace conv_bias;
+
+size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::get_workspace_in_bytes(
+        const SizeArgs& args) const {
+    auto workspace_size = cudnn_get_workspace_in_bytes(args);
+    auto&& param = args.opr->param();
+    if (args.preprocessed_filter == nullptr) {
+        if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
+            args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
+            // cudnn requires the bias to be float when executing CONFIG_INT;
+            // convert bias to float if it is not float in the first place
+            workspace_size += sizeof(float) * args.bias_layout->span().dist_elem();
+        }
+        if (param.format == param::ConvBias::Format::NCHW32) {
+            workspace_size += args.filter_layout->span().dist_byte() +
+                              args.bias_layout->span().dist_byte();
+        }
+    }
+    return workspace_size;
+}
+
+void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::exec(
+        const ExecArgs& args) const {
+    float alpha, beta;
+    std::tie(alpha, beta) = cudnn_get_conv_bias_act_scale_param(
+            args.src_tensor->layout, args.dst_tensor->layout,
+            args.filter_tensor->layout, args.bias_tensor->layout,
+            args.z_tensor->layout);
+
+    auto workspace_ptr = args.workspace.raw_ptr;
+    auto workspace_size = args.workspace.size;
+    auto bias_ptr = args.bias_tensor->raw_ptr();
+    TensorND filter_tensor;
+    TensorND bias_tensor;
+    auto&& param = args.opr->param();
+    if (args.preprocessed_filter != nullptr) {
+        bias_tensor = TensorND{
+                args.bias_tensor->layout,
+                args.preprocessed_filter->tensors[0].raw_ptr()};
+        if (param.format == Param::Format::NCHW32) {
+            megdnn_assert(args.preprocessed_filter->tensors.size() == 2);
+            filter_tensor = TensorND{
+                    args.filter_tensor->layout,
+                    args.preprocessed_filter->tensors[1].raw_ptr()};
+        } else {
+            filter_tensor = *args.filter_tensor;
+        }
+    } else {
+        if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
+            args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
+            auto cvt = args.handle->create_operator<TypeCvt>();
+            auto float_bias_layout = *args.bias_layout;
+            auto converted_bias_layout = *args.bias_layout;
+            converted_bias_layout.dtype = dtype::QuantizedS32(alpha);
+            float_bias_layout.dtype = dtype::Float32();
+            auto bias_size_in_bytes = float_bias_layout.span().dist_byte();
+            megdnn_assert(args.workspace.size >= bias_size_in_bytes);
+            cvt->exec(
+                    {args.bias_tensor->raw_ptr(), converted_bias_layout},
+                    TensorND{workspace_ptr, float_bias_layout});
+
+            bias_ptr = workspace_ptr;
+            workspace_ptr += bias_size_in_bytes;
+            workspace_size -= bias_size_in_bytes;
+        }
+        if (param.format == Param::Format::NCHW32) {
+            size_t reorder_workspace_size =
+                    args.filter_tensor->layout.span().dist_byte() +
+                    args.bias_tensor->layout.span().dist_byte();
+            auto reorder_filter_ptr = workspace_ptr;
+            auto reorder_bias_ptr =
+                    workspace_ptr + args.filter_tensor->layout.span().dist_byte();
+            cudnn_reorder_filer_and_bias_nchw32(
+                    cudnn_handle(args.opr->handle()), args.filter_tensor->raw_ptr(),
+                    args.filter_meta, bias_ptr, reorder_filter_ptr, reorder_bias_ptr);
+            filter_tensor = TensorND(args.filter_tensor->layout, reorder_filter_ptr);
+            bias_ptr = reorder_bias_ptr;
+            workspace_ptr += reorder_workspace_size;
+            workspace_size -= reorder_workspace_size;
+        } else {
+            filter_tensor = *args.filter_tensor;
+        }
+    }
+    bias_tensor = TensorND{args.bias_tensor->layout, bias_ptr};
+
+    ExecArgs exec_args{
+            const_cast<ConvBiasForwardImpl*>(args.opr),
+            *args.src_tensor,
+            filter_tensor,
+            bias_tensor,
+            *args.z_tensor,
+            *args.dst_tensor,
+            args.workspace};
+    Workspace cudnn_workspace{workspace_ptr, workspace_size};
+
+    cudnn_execute(exec_args, cudnn_workspace, alpha, beta);
+
+    // nonlinear mode epilogue
+    switch (args.nonlinear_mode) {
+        case param::ConvBias::NonlineMode::RELU:
+            break;
+        case param::ConvBias::NonlineMode::SIGMOID: {
+            megdnn_assert(
+                    args.dst_layout->dtype.category() != DTypeCategory::QUANTIZED);
+            auto&& elem_opr = args.handle->create_operator<ElemwiseForward>();
+            elem_opr->param().mode = Elemwise::Param::Mode::SIGMOID;
+            elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
+            break;
+        }
+        case param::ConvBias::NonlineMode::IDENTITY:
+            break;
+        case param::ConvBias::NonlineMode::H_SWISH: {
+            megdnn_assert(
+                    args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED ||
+                    (args.dst_layout->dtype.category() == DTypeCategory::FLOAT &&
+                     args.opr->param().format == param::ConvBias::Format::NCHW4_NCHW));
+            if (args.dst_layout->dtype.category() == DTypeCategory::QUANTIZED) {
+                auto&& elem_opr = args.handle->create_operator<ElemwiseMultiType>();
+                elem_opr->param().mode = ElemwiseMultiType::Param::Mode::QH_SWISH;
+                elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
+            } else {
+                auto&& elem_opr = args.handle->create_operator<ElemwiseForward>();
+                elem_opr->param().mode = ElemwiseForward::Param::Mode::H_SWISH;
+                elem_opr->exec({*(args.dst_tensor)}, *(args.dst_tensor));
+            }
+            break;
+        }
+        default:
+            megdnn_throw("unsupported NonlineMode");
+    }
+}
+
+size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::
+        get_preprocess_workspace_in_bytes(const SizeArgs& args) const {
+    auto&& param = args.opr->param();
+    if (param.format == Param::Format::NCHW32) {
+        return args.bias_layout->span().dist_byte();
+    }
+    return 0_z;
+}
+
+SmallVector<TensorLayout> ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::
+        deduce_preprocessed_filter_layout(const SizeArgs& args) const {
+    auto&& param = args.opr->param();
+    if (param.format == Param::Format::NCHW32) {
+        return {args.bias_layout->collapse_contiguous(),
+                args.filter_layout->collapse_contiguous()};
+    } else {
+        return {args.bias_layout->collapse_contiguous()};
+    }
+}
+
+void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationBase::exec_preprocess(
+        const ExecArgs& args) const {
+    float alpha, beta;
+    std::tie(alpha, beta) = cudnn_get_conv_bias_act_scale_param(
+            args.src_tensor->layout, args.dst_tensor->layout,
+            args.filter_tensor->layout, args.bias_tensor->layout,
+            args.z_tensor->layout);
+    MEGDNN_MARK_USED_VAR(beta);
+
+    auto workspace_ptr = args.workspace.raw_ptr;
+    auto workspace_size = args.workspace.size;
+    auto bias_ptr = workspace_size > 0
+                          ? workspace_ptr
+                          : args.preprocessed_filter->tensors[0].raw_ptr();
+    if (args.bias_layout && args.bias_layout->dtype != dtype::Float32() &&
+        args.src_layout->dtype.category() != DTypeCategory::FLOAT) {
+        auto cvt = args.handle->create_operator<TypeCvt>();
+        auto float_bias_layout = *args.bias_layout;
+        auto converted_bias_layout = *args.bias_layout;
+        converted_bias_layout.dtype = dtype::QuantizedS32(alpha);
+        float_bias_layout.dtype = dtype::Float32();
+        cvt->exec(
+                {args.bias_tensor->raw_ptr(), converted_bias_layout},
+                TensorND{bias_ptr, float_bias_layout});
+    }
+    if (args.opr->param().format == Param::Format::NCHW32) {
+        auto reorder_filter_ptr = args.preprocessed_filter->tensors[1].raw_ptr();
+        auto reorder_bias_ptr = args.preprocessed_filter->tensors[0].raw_ptr();
+        cudnn_reorder_filer_and_bias_nchw32(
+                cudnn_handle(args.opr->handle()), args.filter_tensor->raw_ptr(),
+                args.filter_meta, bias_ptr, reorder_filter_ptr, reorder_bias_ptr);
+    }
+}
+
+// vim: syntax=cpp.doxygen
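`deduce_preprocessed_filter_layout` / `exec_preprocess` implement weight pre-processing: the input-independent transforms (bias-to-float conversion, NCHW32 filter/bias reordering via cuDNN's `cudnnReorderFilterAndBias`) run once at load time, and `exec` then consumes the cached tensors instead of redoing the work per call. A schematic of that split, with hypothetical types rather than MegDNN's API:

```cpp
#include <cstdio>
#include <vector>

// Preprocess-once / execute-many split, mirroring exec_preprocess() + exec().
struct PreprocessedFilter {
    std::vector<float> reordered_bias;    // slot 0 in the diff above
    std::vector<float> reordered_filter;  // slot 1, NCHW32 formats only
};

PreprocessedFilter preprocess(
        const std::vector<float>& bias, const std::vector<float>& filter) {
    PreprocessedFilter out;
    out.reordered_bias = bias;      // stand-in for the int->float TypeCvt
    out.reordered_filter = filter;  // stand-in for cudnnReorderFilterAndBias
    return out;
}

void exec(const PreprocessedFilter* pf) {
    // With pf set, exec skips conversion/reordering and reads the cached
    // tensors; with pf == nullptr it must redo the work in its own workspace,
    // which is why get_workspace_in_bytes() only adds the extra bytes then.
    std::printf("exec: %s\n", pf ? "using preprocessed tensors" : "reordering inline");
}

int main() {
    std::vector<float> bias(32, 1.f), filter(32 * 32 * 9, 0.5f);
    auto pf = preprocess(bias, filter);  // once, at model-load time
    for (int i = 0; i < 3; ++i)          // many times, at inference time
        exec(&pf);
    return 0;
}
```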
@@ -0,0 +1,145 @@
+/**
+ * \file dnn/src/cuda/conv_bias/cudnn_conv_bias_activation_v8.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "megdnn/oprs/general.h"
+
+#include "./algo.h"
+
+#include "src/common/conv_bias.h"
+#include "src/cuda/cudnn_wrapper_v8.h"
+#include "src/cuda/utils.h"
+
+#if CUDNN_VERSION >= 8004
+using namespace megdnn;
+using namespace cuda;
+using namespace conv_bias;
+
+namespace {
+TensorLayout canonical_bias_layout(
+        const TensorLayout& bias_layout, const param::ConvBias::Format format) {
+    int64_t vector_count, vector_dimension;
+    std::tie(vector_count, vector_dimension) = get_vector_count_and_dimension(format);
+    size_t channel = bias_layout[vector_dimension] * vector_count;
+    if (bias_layout.dtype.category() != DTypeCategory::FLOAT) {
+        return TensorLayout{{1, channel, 1, 1}, dtype::Float32()};
+    }
+    return TensorLayout{{1, channel, 1, 1}, bias_layout.dtype};
+}
+}  // namespace
+
+bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationV8::is_available(
+        const SizeArgs& args) const {
+    auto&& param = args.opr->param();
+    if (param.format == param::ConvBias::Format::NCHW4_NCHW32 ||
+        param.format == param::ConvBias::Format::NCHW32_NCHW4 ||
+        param.format == param::ConvBias::Format::NCHW4_NCHW ||
+        param.format == param::ConvBias::Format::NCHW8 ||
+        param.format == param::ConvBias::Format::NCHW64 ||
+        param.format == param::ConvBias::Format::CHWN4)
+        return false;
+    if (param.format != Param::Format::NCHW && param.format != Param::Format::NHWC) {
+        if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) {
+            return false;
+        }
+    }
+    if ((args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
+         args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
+        args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4)
+        return false;
+    if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
+        args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm)
+        return false;
+    if (args.src_layout->dtype == args.filter_layout->dtype &&
+        args.src_layout->dtype == dtype::BFloat16()) {
+        return false;
+    }
+    if (args.bias_layout->ndim == 0 ||
+        !check_bias_share_in_channel(*(args.bias_layout), param.format)) {
+        return false;
+    }
+    // FIXME: cudnn cannot handle the case where the initial value of the dst
+    // tensor contains nan and beta is zero, because the result of 0.f * nan
+    // is still nan
+    if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
+        args.dst_layout->dtype.enumv() == DTypeEnum::Float32 &&
+        param.format == param::ConvBias::Format::NCHW) {
+        return false;
+    }
+    if (param.format == param::ConvBias::Format::NCHW32) {
+        // check the sm version
+        auto&& device_prop = current_device_prop();
+        if (device_prop.major < 7 || (device_prop.major == 7 && device_prop.minor < 5))
+            return false;
+    }
+    switch (args.nonlinear_mode) {
+        case param::ConvBias::NonlineMode::RELU:
+        case param::ConvBias::NonlineMode::IDENTITY:
+            break;
+        case param::ConvBias::NonlineMode::SIGMOID:
+            // forbids sigmoid for quantized
+            if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED)
+                return false;
+            break;
+        case param::ConvBias::NonlineMode::H_SWISH:
+            if (args.src_layout->dtype.category() == DTypeCategory::QUANTIZED)
+                break;
+            return false;
+        default:
+            megdnn_throw("unsupported NonlineMode");
+    }
+
+    auto bias_layout =
+            canonical_bias_layout(*args.bias_layout, args.opr->param().format);
+    auto plan = get_heuristic_plan_from_opr(
+            static_cast<const ConvBiasForward*>(args.opr), *args.src_layout,
+            *args.dst_layout, *args.filter_layout, bias_layout, *args.z_layout,
+            args.filter_meta);
+    return plan != nullptr;
+}
+
+size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationV8::cudnn_get_workspace_in_bytes(
+        const SizeArgs& args) const {
+    auto bias_layout =
+            canonical_bias_layout(*args.bias_layout, args.opr->param().format);
+    auto plan = get_heuristic_plan_from_opr(
+            static_cast<const ConvBiasForward*>(args.opr), *args.src_layout,
+            *args.dst_layout, *args.filter_layout, bias_layout, *args.z_layout,
+            args.filter_meta);
+    megdnn_assert(
+            plan != nullptr, "algo(%s) cannot find execution plan from heuristics",
+            name());
+    return plan->getWorkspaceSize();
+}
+
+void ConvBiasForwardImpl::AlgoCUDNNConvBiasActivationV8::cudnn_execute(
+        const ExecArgs& args, const Workspace& workspace, float alpha,
+        float beta) const {
+    auto&& bias_layout =
+            canonical_bias_layout(args.bias_tensor->layout, args.opr->param().format);
+    auto plan = get_heuristic_plan_from_opr(
+            static_cast<const ConvBiasForward*>(args.opr), args.src_tensor->layout,
+            args.dst_tensor->layout, args.filter_tensor->layout, bias_layout,
+            args.z_tensor->layout, args.filter_meta);
+    megdnn_assert(
+            plan != nullptr, "algo(%s) cannot find execution plan from heuristics",
+            name());
+    auto&& handle = cudnn_handle(args.handle);
+    TensorND bias_tensor{args.bias_tensor->raw_ptr(), bias_layout};
+    run_conv_bias_act_with_plan(
+            handle, *plan, *args.src_tensor, *args.dst_tensor, *args.filter_tensor,
+            bias_tensor, *args.z_tensor, workspace);
+}
+#endif
+
+// vim: syntax=cpp.doxygen
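Both V8 algorithms follow the same three steps: ask cuDNN's heuristics for an execution plan, size the workspace from the plan, then launch it. `get_heuristic_plan_from_opr` and `run_conv_bias_act_with_plan` live in the new `cudnn_wrapper_v8` (its source is truncated at the end of this diff). The sketch below shows only the shape of that flow, with hypothetical stand-in types rather than MegDNN's or cudnn-frontend's actual signatures:

```cpp
#include <memory>

// Hypothetical stand-ins for the cudnn_wrapper_v8 helpers referenced above.
struct ExecutionPlan {
    size_t workspace = 0;
    size_t getWorkspaceSize() const { return workspace; }
};

// Queries the v8 heuristics for the given problem; may fail and return null,
// which is why is_available() treats "no plan" as "algorithm unavailable".
std::unique_ptr<ExecutionPlan> get_heuristic_plan(/* layouts, conv params */) {
    auto plan = std::make_unique<ExecutionPlan>();
    plan->workspace = 1 << 20;  // placeholder value for illustration
    return plan;
}

void run_with_plan(const ExecutionPlan& plan, void* workspace) {
    // In the real wrapper this binds device pointers plus the workspace into
    // a variant pack and launches the plan via the cuDNN v8 backend API.
    (void)plan;
    (void)workspace;
}

int main() {
    // 1. is_available(): does a plan exist?  2. size workspace  3. execute
    auto plan = get_heuristic_plan();
    if (!plan)
        return 1;
    std::unique_ptr<char[]> ws(new char[plan->getWorkspaceSize()]);
    run_with_plan(*plan, ws.get());
    return 0;
}
```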
@@ -0,0 +1,98 @@
+/**
+ * \file dnn/src/cuda/conv_bias/cudnn_conv_v8.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "src/common/conv_bias.h"
+#include "src/cuda/conv_bias/algo.h"
+#include "src/cuda/cudnn_wrapper_v8.h"
+#include "src/cuda/utils.h"
+
+#if CUDNN_VERSION >= 8004
+using namespace megdnn;
+using namespace cuda;
+using namespace conv_bias;
+
+bool ConvBiasForwardImpl::AlgoCUDNNConvV8::is_available(const SizeArgs& args) const {
+    if (args.filter_meta.format != Param::Format::NCHW &&
+        args.filter_meta.format != Param::Format::NHWC) {
+        if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) {
+            return false;
+        }
+    }
+    if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
+        args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) {
+        return false;
+    }
+    // FIXME: cudnn cannot handle the case where the initial value of the dst
+    // tensor contains nan and beta is zero, because the result of 0.f * nan
+    // is still nan
+    if (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
+        args.dst_layout->dtype.enumv() == DTypeEnum::Float32 &&
+        args.opr->param().format == param::ConvBias::Format::NCHW) {
+        return false;
+    }
+    auto dst_layout = *args.dst_layout;
+    if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
+        dst_layout.dtype = DType();
+        args.opr->check_or_deduce_dtype_fwd(
+                args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype);
+    }
+    SizeArgs conv_args = args;
+    conv_args.dst_layout = &dst_layout;
+    if (!is_cudnn_supported(conv_args))
+        return false;
+
+    auto conv_opr = args.handle->create_operator<ConvolutionForward>();
+    conv_opr->param() = get_param_convolution(args);
+    ConvolutionForward::CanonizedFilterMeta fm;
+    fm.copy_from(args.filter_meta);
+    auto plan = get_heuristic_plan_from_opr(
+            conv_opr.get(), *conv_args.src_layout, *conv_args.dst_layout,
+            *conv_args.filter_layout, {}, {}, fm);
+    return plan != nullptr;
+}
+
+size_t ConvBiasForwardImpl::AlgoCUDNNConvV8::cudnn_get_workspace_in_bytes(
+        const SizeArgs& args) const {
+    auto conv_opr = args.handle->create_operator<ConvolutionForward>();
+    conv_opr->param() = get_param_convolution(args);
+    ConvolutionForward::CanonizedFilterMeta fm;
+    fm.copy_from(args.filter_meta);
+    auto plan = get_heuristic_plan_from_opr(
+            conv_opr.get(), *args.src_layout, *args.dst_layout, *args.filter_layout, {},
+            {}, fm);
+    megdnn_assert(
+            plan != nullptr, "algo(%s) cannot find execution plan from heuristics",
+            name());
+    return plan->getWorkspaceSize();
+}
+
+void ConvBiasForwardImpl::AlgoCUDNNConvV8::cudnn_execute(
+        const ExecArgs& args, const Workspace& workspace) const {
+    auto conv_opr = args.handle->create_operator<ConvolutionForward>();
+    conv_opr->param() = get_param_convolution(args);
+    ConvolutionForward::CanonizedFilterMeta fm;
+    fm.copy_from(args.filter_meta);
+    auto plan = get_heuristic_plan_from_opr(
+            conv_opr.get(), args.src_tensor->layout, args.dst_tensor->layout,
+            args.filter_tensor->layout, {}, {}, fm);
+    megdnn_assert(
+            plan != nullptr, "algo(%s) cannot find execution plan from heuristics",
+            name());
+    auto&& handle = cudnn_handle(args.handle);
+    run_single_conv_with_plan(
+            handle, *plan, *args.src_tensor, *args.dst_tensor, *args.filter_tensor,
+            workspace);
+}
+#endif
+
+// vim: syntax=cpp.doxygen
| @@ -197,8 +197,60 @@ void flip_filter( | |||||
| ref_ptr.reset(workspace.raw_ptr); | ref_ptr.reset(workspace.raw_ptr); | ||||
| } | } | ||||
| } // namespace conv_bias | |||||
| std::pair<float, float> cudnn_get_conv_bias_act_scale_param( | |||||
| const TensorLayout& x, const TensorLayout& y, const TensorLayout& w, | |||||
| const TensorLayout& b, const TensorLayout& z) { | |||||
| float alpha = 1.f, beta = 0.f; | |||||
| if (z.ndim > 0) | |||||
| beta = 1.f; | |||||
| auto get_scale = [](const DType& dtype) -> float { | |||||
| megdnn_assert(dtype.category() == DTypeCategory::QUANTIZED); | |||||
| switch (dtype.enumv()) { | |||||
| #define cb(_dt) \ | |||||
| case DTypeTrait<_dt>::enumv: \ | |||||
| return dtype.param<_dt>().scale; | |||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | |||||
| #undef cb | |||||
| default: | |||||
| megdnn_assert_internal(0); | |||||
| } | |||||
| }; | |||||
| auto x_dtype = x.dtype, y_dtype = y.dtype, w_dtype = w.dtype; | |||||
| megdnn_assert( | |||||
| (x_dtype.category() == y_dtype.category()) || | |||||
| (x_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
| y_dtype.enumv() == DTypeEnum::Float32)); | |||||
| megdnn_assert(x_dtype.category() == w_dtype.category()); | |||||
| if (x_dtype.category() == DTypeCategory::QUANTIZED) { | |||||
| auto expected_bias_scale = get_scale(x_dtype) * get_scale(w_dtype); | |||||
| alpha = expected_bias_scale; | |||||
| if (y_dtype.category() == DTypeCategory::QUANTIZED) | |||||
| alpha /= get_scale(y_dtype); | |||||
| if (z.ndim > 0 && z.dtype.category() == DTypeCategory::QUANTIZED) { | |||||
| beta = get_scale(z.dtype) / get_scale(y_dtype); | |||||
| } | |||||
| if (b.dtype.category() == DTypeCategory::QUANTIZED) { | |||||
| megdnn_assert(fabs(expected_bias_scale - get_scale(b.dtype)) < 1e-4); | |||||
| } | |||||
| } | |||||
| return {alpha, beta}; | |||||
| } | |||||
| void cudnn_reorder_filer_and_bias_nchw32( | |||||
| const cudnnHandle_t& handle, const void* filter_ptr, | |||||
| const CanonizedFilterMeta& fm, const void* bias_ptr, void* reordered_filter_ptr, | |||||
| void* reordered_bias_ptr) { | |||||
| FilterDesc<param::ConvBias> filter_desc; | |||||
| filter_desc.set(fm); | |||||
| int reorder_bias = bias_ptr != nullptr; | |||||
| cudnn_check(cudnnReorderFilterAndBias( | |||||
| handle, filter_desc.desc, CUDNN_DEFAULT_REORDER, filter_ptr, | |||||
| reordered_filter_ptr, reorder_bias, bias_ptr, reordered_bias_ptr)); | |||||
| } | |||||
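A hypothetical call site for the reorder helper (function and buffer names are ours; allocation and sizing are assumed to be handled by the caller, and the megdnn::cuda namespace is assumed in scope). Passing a null bias pointer skips bias reordering, since reorder_bias is derived from bias_ptr:

    // Reorder an NCHW32 filter (and optionally its bias) into the layout
    // cuDNN's IMMA kernels expect, before launching the fused conv.
    void prepare_imma_weights(
            cudnnHandle_t handle, const CanonizedFilterMeta& fm,
            const void* filter_dev, const void* bias_dev,
            void* filter_reordered_dev, void* bias_reordered_dev) {
        conv_bias::cudnn_reorder_filer_and_bias_nchw32(
                handle, filter_dev, fm, bias_dev, filter_reordered_dev,
                bias_reordered_dev);
    }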
| } // namespace conv_bias | |||||
| } // namespace cuda | } // namespace cuda | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -113,6 +113,15 @@ struct CUDNNForwardDescs { | |||||
| } | } | ||||
| }; | }; | ||||
| std::pair<float, float> cudnn_get_conv_bias_act_scale_param( | |||||
| const TensorLayout& x, const TensorLayout& y, const TensorLayout& w, | |||||
| const TensorLayout& b, const TensorLayout& z); | |||||
| void cudnn_reorder_filer_and_bias_nchw32( | |||||
| const cudnnHandle_t& handle, const void* filter_ptr, | |||||
| const CanonizedFilterMeta& fm, const void* bias_ptr, void* reordered_filter_ptr, | |||||
| void* reordered_bias_ptr); | |||||
| } // namespace conv_bias | } // namespace conv_bias | ||||
| } // namespace cuda | } // namespace cuda | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -47,6 +47,17 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||||
| const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { | const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) { | ||||
| using namespace conv_bias; | using namespace conv_bias; | ||||
| AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; | AlgoBase::SizeArgs args{this, src, filter, bias, z, dst}; | ||||
| #if CUDNN_VERSION > 8004 | |||||
| if (sm_algo_pack.cudnn_conv_v8.is_available_attribute( | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.cudnn_conv_v8; | |||||
| } | |||||
| if (sm_algo_pack.cudnn_conv_bias_activation_v8.is_available_attribute( | |||||
| args, positive_attr, negative_attr, workspace_limit_in_bytes)) { | |||||
| return &sm_algo_pack.cudnn_conv_bias_activation_v8; | |||||
| } | |||||
| #endif | |||||
| auto dst_layout = *args.dst_layout; | auto dst_layout = *args.dst_layout; | ||||
| if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { | if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { | ||||
| dst_layout.dtype = DType(); | dst_layout.dtype = DType(); | ||||
| @@ -1,6 +1,7 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "../elemwise/opr_impl.h" | #include "../elemwise/opr_impl.h" | ||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "src/cuda/cudnn_with_check.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace cuda { | namespace cuda { | ||||
| @@ -65,6 +66,12 @@ public: | |||||
| // The following algorithms are suitable for channel wise convolution | // The following algorithms are suitable for channel wise convolution | ||||
| class AlgoFloat32NCHWFMAImplicitBatchedGemm; | class AlgoFloat32NCHWFMAImplicitBatchedGemm; | ||||
| class AlgoFloat16NCHWHMMAImplicitBatchedGemm; | class AlgoFloat16NCHWHMMAImplicitBatchedGemm; | ||||
| class AlgoCUDNNConvBase; | |||||
| class AlgoCUDNNConvBiasActivationBase; | |||||
| #if CUDNN_VERSION > 8004 | |||||
| class AlgoCUDNNConvV8; | |||||
| class AlgoCUDNNConvBiasActivationV8; | |||||
| #endif | |||||
| class AlgoPack; | class AlgoPack; | ||||
| @@ -0,0 +1,685 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/cudnn_wrapper_v8.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "src/cuda/cudnn_wrapper_v8.h" | |||||
| #include "src/cuda/cudnn_wrapper.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/cuda/utils.h" | |||||
| #include "src/cuda/conv_bias/helper.h" | |||||
| #include "cudnn_frontend_EngineConfigGenerator.h" | |||||
| #include "megdnn/heuristic_cache.h" | |||||
| using namespace megdnn; | |||||
| using namespace cuda; | |||||
| // helper functions for underlying descriptors | |||||
| namespace { | |||||
| cudnnDataType_t get_cudnn_data_type(DType type) { | |||||
| switch (type.enumv()) { | |||||
| case DTypeEnum::Float32: | |||||
| return CUDNN_DATA_FLOAT; | |||||
| case DTypeEnum::Float16: | |||||
| return CUDNN_DATA_HALF; | |||||
| case DTypeEnum::Int32: | |||||
| case DTypeEnum::QuantizedS32: | |||||
| return CUDNN_DATA_INT32; | |||||
| case DTypeEnum::QuantizedS8: | |||||
| case DTypeEnum::Int8: | |||||
| return CUDNN_DATA_INT8; | |||||
| default: | |||||
| megdnn_throw("dtype must be float16/float32/int8/qint8/int32/qint32"); | |||||
| } | |||||
| } | |||||
| cudnnDataType_t get_compute_type( | |||||
| DType type, param::Convolution::ComputeMode comp_mode) { | |||||
| if (type.enumv() == DTypeEnum::Float32) { | |||||
| return CUDNN_DATA_FLOAT; | |||||
| } else if (type.enumv() == DTypeEnum::Float16) { | |||||
| return get_compute_type_fp16(comp_mode); | |||||
| } else if ( | |||||
| type.category() == DTypeCategory::INT || | |||||
| type.category() == DTypeCategory::QUANTIZED) { | |||||
| return CUDNN_DATA_INT32; | |||||
| } else { | |||||
| megdnn_throw("unsupported compute type for convolution"); | |||||
| } | |||||
| } | |||||
| using Format = param::Convolution::Format; | |||||
| using IntArrayRef = SmallVector<int64_t>; | |||||
| std::pair<IntArrayRef, IntArrayRef> get_shape_and_stride( | |||||
| const TensorLayout& layout, const Format format, int64_t nr_group) { | |||||
| // DENSE: n, c, h, w | |||||
| // n, k, p, q; ndim = 4 | |||||
| // GROUP: n, g, c, h, w | |||||
| // n, g, k, p, q; ndim = 5 | |||||
| static constexpr size_t CUDNN_NDIM = 4; | |||||
| size_t cudnn_ndim = CUDNN_NDIM; | |||||
| if (nr_group > 1) | |||||
| cudnn_ndim += 1; | |||||
| IntArrayRef shape(cudnn_ndim); | |||||
| IntArrayRef stride(cudnn_ndim); | |||||
| if (format == Format::NCHW4 || format == Format::NCHW32) | |||||
| megdnn_assert_eq_size_t(layout.ndim, 5_z); | |||||
| else | |||||
| megdnn_assert_eq_size_t(layout.ndim, 4_z); | |||||
| size_t c_pos, spatial_pos; | |||||
| if (format == Format::NCHW || format == Format::NCHW4 || format == Format::NCHW32) { | |||||
| c_pos = 1; | |||||
| spatial_pos = 2; | |||||
| } else { | |||||
| megdnn_assert(format == Format::NHWC); | |||||
| c_pos = 3; | |||||
| spatial_pos = 1; | |||||
| } | |||||
| int64_t vector_count, vector_dimension; | |||||
| std::tie(vector_count, vector_dimension) = get_vector_count_and_dimension(format); | |||||
| size_t out_c_pos = nr_group == 1 ? 1 : 2; | |||||
| size_t out_spatial_pos = nr_group == 1 ? 2 : 3; | |||||
| // For NCHW4 and NCHW32 we still compute standard contiguous strides from the | |||||
| // folded shape before handing them to cuDNN; the channel packing itself is | |||||
| // conveyed separately via setVectorCountAndDimension on the descriptor. | |||||
| shape[0] = layout[0]; | |||||
| if (nr_group > 1) | |||||
| shape[1] = nr_group; | |||||
| shape[out_c_pos] = layout[c_pos] / nr_group; | |||||
| shape[out_spatial_pos] = layout[spatial_pos]; | |||||
| shape[out_spatial_pos + 1] = layout[spatial_pos + 1]; | |||||
| if (c_pos == 1) { | |||||
| stride[cudnn_ndim - 1] = 1; | |||||
| for (int i = cudnn_ndim - 2; i >= 0; --i) { | |||||
| stride[i] = stride[i + 1] * shape[i + 1]; | |||||
| } | |||||
| } else { | |||||
| megdnn_assert(c_pos == 3); // Here we assume that the format is NHWC | |||||
| stride[out_c_pos] = 1; | |||||
| if (nr_group > 1) | |||||
| stride[1] = shape[out_c_pos] * stride[out_c_pos]; | |||||
| stride[out_spatial_pos + 1] = stride[1] * shape[1]; | |||||
| stride[out_spatial_pos] = | |||||
| stride[out_spatial_pos + 1] * shape[out_spatial_pos + 1]; | |||||
| stride[0] = stride[out_spatial_pos] * shape[out_spatial_pos]; | |||||
| } | |||||
| return {shape, stride}; | |||||
| } | |||||
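For intuition, a worked example with our own numbers (not from the patch): a dense NCHW activation of layout (1, 64, 7, 7) split into 8 groups folds into a 5-d cuDNN tensor, and this sketch recomputes the packed strides the helper derives for it:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
        std::vector<int64_t> shape{1, 8, 8, 7, 7};  // n, g, c/g, h, w
        std::vector<int64_t> stride(5);
        stride[4] = 1;  // fully packed, row-major over the folded shape
        for (int i = 3; i >= 0; --i)
            stride[i] = stride[i + 1] * shape[i + 1];
        assert((stride == std::vector<int64_t>{3136, 392, 49, 7, 1}));
        return 0;
    }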
| /* --------------- make cudnn-frontend tensor descriptor --------------- */ | |||||
| auto make_tensor_descriptor( | |||||
| int64_t id, uint8_t alignment, const TensorLayout& layout, const Format format, | |||||
| int64_t nr_group, bool is_virtual = false) { | |||||
| int64_t vector_count, vector_dimension; | |||||
| std::tie(vector_count, vector_dimension) = get_vector_count_and_dimension(format); | |||||
| IntArrayRef shape, stride; | |||||
| std::tie(shape, stride) = get_shape_and_stride(layout, format, nr_group); | |||||
| return cudnn_frontend::TensorBuilder() | |||||
| .setDim(shape.size(), shape.data()) | |||||
| .setStrides(stride.size(), stride.data()) | |||||
| .setId(id) | |||||
| .setAlignment(alignment) | |||||
| .setDataType(get_cudnn_data_type(layout.dtype)) | |||||
| .setVirtual(is_virtual) | |||||
| .setVectorCountAndDimension(vector_count, vector_dimension) | |||||
| .build(); | |||||
| } | |||||
| /* --------------- make cudnn-frontend filter descriptor --------------- */ | |||||
| template <typename FilterMeta> | |||||
| cudnn_frontend::Tensor make_filter_descriptor(uint8_t alignment, const FilterMeta& fm) { | |||||
| // DENSE: k, c, r, s; ndim = 4 | |||||
| // GROUP: g, k, c, r, s; ndim = 5 | |||||
| // generate shape and stride | |||||
| static constexpr size_t CUDNN_NDIM = 4; | |||||
| size_t cudnn_ndim = CUDNN_NDIM; | |||||
| if (fm.group > 1) | |||||
| cudnn_ndim += 1; | |||||
| IntArrayRef shape(cudnn_ndim), stride(cudnn_ndim); | |||||
| auto format = fm.format; | |||||
| int64_t vector_count, vector_dimension; | |||||
| std::tie(vector_count, vector_dimension) = get_vector_count_and_dimension(format); | |||||
| int64_t group = fm.group; | |||||
| size_t out_ch_pos = group == 1 ? 0 : 1; | |||||
| size_t in_ch_pos = group == 1 ? 1 : 2; | |||||
| size_t filter_start = group == 1 ? 2 : 3; | |||||
| if (group > 1) | |||||
| shape[0] = group; | |||||
| shape[out_ch_pos] = fm.ocpg; | |||||
| shape[in_ch_pos] = fm.icpg / vector_count; | |||||
| shape[filter_start] = fm.spatial[0]; | |||||
| shape[filter_start + 1] = fm.spatial[1]; | |||||
| if (format == Format::NCHW || format == Format::NCHW4 || format == Format::NCHW32) { | |||||
| stride[cudnn_ndim - 1] = 1; | |||||
| for (int i = cudnn_ndim - 2; i >= 0; --i) { | |||||
| stride[i] = stride[i + 1] * shape[i + 1]; | |||||
| } | |||||
| } else { | |||||
| megdnn_assert( | |||||
| format == Format::NHWC); // Here we assume that the format is NHWC | |||||
| stride[in_ch_pos] = 1; | |||||
| stride[filter_start + 1] = stride[in_ch_pos] * shape[in_ch_pos]; | |||||
| stride[filter_start] = stride[filter_start + 1] * shape[filter_start + 1]; | |||||
| stride[out_ch_pos] = stride[filter_start] * shape[filter_start]; | |||||
| if (group > 1) | |||||
| stride[0] = stride[out_ch_pos] * shape[out_ch_pos]; | |||||
| } | |||||
| return cudnn_frontend::TensorBuilder() | |||||
| .setDim(shape.size(), shape.data()) | |||||
| .setStrides(stride.size(), stride.data()) | |||||
| .setId('w') // weight descriptor | |||||
| .setAlignment(alignment) | |||||
| .setDataType(get_cudnn_data_type(fm.dtype)) | |||||
| .setVectorCountAndDimension(vector_count, vector_dimension) | |||||
| .build(); | |||||
| } | |||||
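The channels-last branch is easiest to see with numbers (again ours): for a grouped NHWC filter with group=8, ocpg=8, icpg=8 and a 3x3 window, the shape is (g, o, c, r, s) = (8, 8, 8, 3, 3) and the strides keep input channels innermost:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
        std::vector<int64_t> shape{8, 8, 8, 3, 3};   // g, ocpg, icpg, r, s
        std::vector<int64_t> stride(5);
        stride[2] = 1;                     // input channels innermost (NHWC)
        stride[4] = stride[2] * shape[2];  // s
        stride[3] = stride[4] * shape[4];  // r
        stride[1] = stride[3] * shape[3];  // output channels
        stride[0] = stride[1] * shape[1];  // groups outermost
        assert((stride == std::vector<int64_t>{576, 72, 1, 24, 8}));
        return 0;
    }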
| /* --------------- make cudnn-frontend conv descriptor --------------- */ | |||||
| template <typename Param> | |||||
| cudnn_frontend::ConvDesc_v8 make_conv_descriptor( | |||||
| cudnnDataType_t data_type, const Param& param) { | |||||
| IntArrayRef padding = {param.pad_h, param.pad_w}; | |||||
| IntArrayRef stride = {param.stride_h, param.stride_w}; | |||||
| IntArrayRef dilation = {param.dilate_h, param.dilate_w}; | |||||
| uint64_t conv_dim = stride.size(); | |||||
| cudnnConvolutionMode_t mode; | |||||
| switch (param.mode) { | |||||
| case Param::Mode::CROSS_CORRELATION: | |||||
| mode = CUDNN_CROSS_CORRELATION; | |||||
| break; | |||||
| case Param::Mode::CONVOLUTION: | |||||
| mode = CUDNN_CONVOLUTION; | |||||
| break; | |||||
| default: | |||||
| megdnn_throw("conv mode must be conv or xcorr."); | |||||
| } | |||||
| return cudnn_frontend::ConvDescBuilder() | |||||
| .setDataType(data_type) | |||||
| .setMathMode(mode) | |||||
| .setNDims(conv_dim) | |||||
| .setStrides(conv_dim, stride.data()) | |||||
| .setPrePadding(conv_dim, padding.data()) | |||||
| .setPostPadding(conv_dim, padding.data()) | |||||
| .setDilation(conv_dim, dilation.data()) | |||||
| .build(); | |||||
| } | |||||
| /* --------------- make cudnn-frontend activation descriptor --------------- */ | |||||
| auto make_activation_descriptor( | |||||
| DType data_type, const param::ConvBias::NonlineMode nonline_mode) { | |||||
| cudnnPointwiseMode_t mode; | |||||
| using NonlineMode = param::ConvBias::NonlineMode; | |||||
| switch (nonline_mode) { | |||||
| case NonlineMode::RELU: | |||||
| mode = CUDNN_POINTWISE_RELU_FWD; | |||||
| break; | |||||
| case NonlineMode::SIGMOID: | |||||
| mode = CUDNN_POINTWISE_SIGMOID_FWD; | |||||
| break; | |||||
| default: | |||||
| megdnn_throw("unsupported non linear mode"); | |||||
| } | |||||
| return cudnn_frontend::PointWiseDescBuilder() | |||||
| .setMode(mode) | |||||
| .setMathPrecision(get_cudnn_data_type(data_type)) | |||||
| .build(); | |||||
| } | |||||
| // high-level api for convolution execution | |||||
| struct StaticData { | |||||
| using Key = megdnn::HeuristicCache::Key; | |||||
| using KeyStorage = megdnn::HeuristicCache::KeyStorage; | |||||
| using KeyHash = megdnn::HeuristicCache::Hash; | |||||
| using Result = cudnn_frontend::ExecutionPlan; | |||||
| using CudnnFrontendExecutionPlanCache = | |||||
| std::unordered_map<KeyStorage, Result, KeyHash>; | |||||
| CudnnFrontendExecutionPlanCache cache; | |||||
| #if __DEPLOY_ON_XP_SP2__ | |||||
| size_t cache_mutex; | |||||
| #else | |||||
| std::mutex cache_mutex; | |||||
| #endif | |||||
| cudnnBackendHeurMode_t heur_mode = CUDNN_HEUR_MODE_INSTANT; | |||||
| bool deterministic = true; | |||||
| }; | |||||
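| // static_data() below is a Meyers singleton: its initialization is | |||||
| // thread-safe in C++11, while plan-cache lookups are still guarded by | |||||
| // cache_mutex. | |||||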
| StaticData& static_data() { | |||||
| static StaticData inst; | |||||
| return inst; | |||||
| } | |||||
| template <typename Opr> | |||||
| struct CudnnBackendOpTypeTrait; | |||||
| template <> | |||||
| struct CudnnBackendOpTypeTrait<ConvolutionForward> { | |||||
| static constexpr cudnnBackendDescriptorType_t OPERATION = | |||||
| CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR; | |||||
| }; | |||||
| template <> | |||||
| struct CudnnBackendOpTypeTrait<ConvolutionBackwardData> { | |||||
| static constexpr cudnnBackendDescriptorType_t OPERATION = | |||||
| CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR; | |||||
| }; | |||||
| template <> | |||||
| struct CudnnBackendOpTypeTrait<ConvolutionBackwardFilter> { | |||||
| static constexpr cudnnBackendDescriptorType_t OPERATION = | |||||
| CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR; | |||||
| }; | |||||
| auto build_opgraph( | |||||
| const cudnnHandle_t& handle, const cudnnBackendDescriptorType_t operation, | |||||
| const cudnn_frontend::Tensor& x, const cudnn_frontend::Tensor& y, | |||||
| const cudnn_frontend::Tensor& w, const cudnn_frontend::ConvDesc_v8& conv_desc) { | |||||
| auto op = cudnn_frontend::OperationBuilder(operation) | |||||
| .setxDesc(x) | |||||
| .setyDesc(y) | |||||
| .setwDesc(w) | |||||
| .setcDesc(conv_desc) | |||||
| .build(); | |||||
| std::array<cudnn_frontend::Operation const*, 1> ops = {&op}; | |||||
| auto op_graph = cudnn_frontend::OperationGraphBuilder() | |||||
| .setHandle(handle) | |||||
| .setOperationGraph(1, ops.data()) | |||||
| .build(); | |||||
| return op_graph; | |||||
| } | |||||
| auto build_opgraph_fused( | |||||
| const cudnnHandle_t& handle, const cudnn_frontend::Tensor& x, | |||||
| const cudnn_frontend::Tensor& y, const cudnn_frontend::Tensor& w, | |||||
| const cudnn_frontend::Tensor& b, const cudnn_frontend::Tensor& z, | |||||
| const cudnn_frontend::Tensor& after_add, | |||||
| const cudnn_frontend::Tensor& after_bias, | |||||
| const cudnn_frontend::Tensor& after_conv, | |||||
| const cudnn_frontend::ConvDesc_v8& conv_desc, | |||||
| const cudnn_frontend::PointWiseDesc_v8& act_desc, float alpha, float beta) { | |||||
| const auto precision = CUDNN_DATA_FLOAT; | |||||
| // add z | |||||
| auto add_desc1 = cudnn_frontend::PointWiseDescBuilder() | |||||
| .setMode(CUDNN_POINTWISE_ADD) | |||||
| .setMathPrecision(precision) | |||||
| .build(); | |||||
| // add bias | |||||
| auto add_desc2 = cudnn_frontend::PointWiseDescBuilder() | |||||
| .setMode(CUDNN_POINTWISE_ADD) | |||||
| .setMathPrecision(precision) | |||||
| .build(); | |||||
| // create conv node | |||||
| auto conv_op = cudnn_frontend::OperationBuilder( | |||||
| CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) | |||||
| .setxDesc(x) | |||||
| .setyDesc(after_conv) | |||||
| .setwDesc(w) | |||||
| .setcDesc(conv_desc) | |||||
| .build(); | |||||
| // create add z node | |||||
| auto add_op1 = cudnn_frontend::OperationBuilder( | |||||
| CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) | |||||
| .setxDesc(conv_op.getOutputTensor()) | |||||
| .setbDesc(z) | |||||
| .setyDesc(after_add) | |||||
| .setpwDesc(add_desc1) | |||||
| .setAlpha(alpha) | |||||
| .setAlpha2(beta) | |||||
| .build(); | |||||
| // create add bias node | |||||
| auto add_op2 = cudnn_frontend::OperationBuilder( | |||||
| CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) | |||||
| .setxDesc(add_op1.getOutputTensor()) | |||||
| .setbDesc(b) | |||||
| .setyDesc(after_bias) | |||||
| .setpwDesc(add_desc2) | |||||
| .build(); | |||||
| // create act node | |||||
| auto act_op = cudnn_frontend::OperationBuilder( | |||||
| CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) | |||||
| .setxDesc(add_op2.getOutputTensor()) | |||||
| .setyDesc(y) | |||||
| .setpwDesc(act_desc) | |||||
| .build(); | |||||
| std::array<cudnn_frontend::Operation const*, 4> ops = { | |||||
| &conv_op, &add_op1, &add_op2, &act_op}; | |||||
| auto op_graph = cudnn_frontend::OperationGraphBuilder() | |||||
| .setHandle(handle) | |||||
| .setOperationGraph(ops.size(), ops.data()) | |||||
| .build(); | |||||
| return op_graph; | |||||
| } | |||||
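In scalar terms, the four-node graph above computes the following per output element (a reference sketch of ours, ignoring quantized rounding and taking RELU as the activation):

    #include <algorithm>

    // conv -> add(z) -> add(bias) -> activation; after_conv/after_add/
    // after_bias correspond to the virtual tensors of the graph.
    float fused_reference(float conv_acc, float z_val, float bias_val,
                          float alpha, float beta) {
        float after_conv = conv_acc;
        float after_add = alpha * after_conv + beta * z_val;  // setAlpha/setAlpha2
        float after_bias = after_add + bias_val;
        return std::max(after_bias, 0.0f);  // RELU
    }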
| auto build_opgraph_fused_nonactivation( | |||||
| const cudnnHandle_t& handle, const cudnn_frontend::Tensor& x, | |||||
| const cudnn_frontend::Tensor& y, const cudnn_frontend::Tensor& w, | |||||
| const cudnn_frontend::Tensor& b, const cudnn_frontend::Tensor& z, | |||||
| const cudnn_frontend::Tensor& after_add, | |||||
| const cudnn_frontend::Tensor& after_conv, | |||||
| const cudnn_frontend::ConvDesc_v8& conv_desc, float alpha, float beta) { | |||||
| const auto precision = CUDNN_DATA_FLOAT; | |||||
| // add z | |||||
| auto add_desc1 = cudnn_frontend::PointWiseDescBuilder() | |||||
| .setMode(CUDNN_POINTWISE_ADD) | |||||
| .setMathPrecision(precision) | |||||
| .build(); | |||||
| // add bias | |||||
| auto add_desc2 = cudnn_frontend::PointWiseDescBuilder() | |||||
| .setMode(CUDNN_POINTWISE_ADD) | |||||
| .setMathPrecision(precision) | |||||
| .build(); | |||||
| // create conv node | |||||
| auto conv_op = cudnn_frontend::OperationBuilder( | |||||
| CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) | |||||
| .setxDesc(x) | |||||
| .setyDesc(after_conv) | |||||
| .setwDesc(w) | |||||
| .setcDesc(conv_desc) | |||||
| .build(); | |||||
| // create add z node | |||||
| auto add_op1 = cudnn_frontend::OperationBuilder( | |||||
| CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) | |||||
| .setxDesc(conv_op.getOutputTensor()) | |||||
| .setbDesc(z) | |||||
| .setyDesc(after_add) | |||||
| .setpwDesc(add_desc1) | |||||
| .setAlpha(alpha) | |||||
| .setAlpha2(beta) | |||||
| .build(); | |||||
| // create add bias node | |||||
| auto add_op2 = cudnn_frontend::OperationBuilder( | |||||
| CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR) | |||||
| .setxDesc(add_op1.getOutputTensor()) | |||||
| .setbDesc(b) | |||||
| .setyDesc(y) | |||||
| .setpwDesc(add_desc2) | |||||
| .build(); | |||||
| std::array<cudnn_frontend::Operation const*, 3> ops = { | |||||
| &conv_op, &add_op1, &add_op2}; | |||||
| auto op_graph = cudnn_frontend::OperationGraphBuilder() | |||||
| .setHandle(handle) | |||||
| .setOperationGraph(ops.size(), ops.data()) | |||||
| .build(); | |||||
| return op_graph; | |||||
| } | |||||
| void filter_engine_configs( | |||||
| cudnn_frontend::EngineConfigList& from, cudnn_frontend::EngineConfigList& to, | |||||
| bool deterministic) { | |||||
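| // cudnn_frontend::filter drops every config for which this predicate | |||||
| // returns true, i.e. "true" means "reject this engine config". | |||||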
| auto filter = [&deterministic](cudnnBackendDescriptor_t c) { | |||||
| if (deterministic) { | |||||
| if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>( | |||||
| c)) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>( | |||||
| c)) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| }; | |||||
| cudnn_frontend::filter(from, to, filter); | |||||
| } | |||||
| } // namespace | |||||
| /* --------- get heuristic plan from megdnn opr -------- */ | |||||
| template <typename Opr> | |||||
| cudnn_frontend::ExecutionPlan* megdnn::cuda::get_heuristic_plan_from_opr( | |||||
| const Opr* opr, const TensorLayout& x, const TensorLayout& y, | |||||
| const TensorLayout& w, const TensorLayout& b, const TensorLayout& z, | |||||
| const typename Opr::CanonizedFilterMeta& fm) { | |||||
| auto&& param = opr->param(); | |||||
| TensorLayoutArray layouts{x, y, w}; | |||||
| auto key = StaticData::Key{opr->handle(), opr->get_opr_type(), | |||||
| layouts.data(), layouts.size(), | |||||
| ¶m, sizeof(param)} | |||||
| .build_key_storage(); | |||||
| auto& cache = static_data().cache; | |||||
| { | |||||
| MEGDNN_LOCK_GUARD(static_data().cache_mutex); | |||||
| auto iter = cache.find(key); | |||||
| if (iter != cache.end()) { | |||||
| return &iter->second; | |||||
| } | |||||
| } | |||||
| size_t aligned = 16; | |||||
| uint8_t alignment = std::min(opr->handle()->alignment_requirement(), aligned); | |||||
| auto&& handle = cudnn_handle(opr->handle()); | |||||
| auto&& x_desc = make_tensor_descriptor('x', alignment, x, fm.format, fm.group); | |||||
| auto&& y_desc = make_tensor_descriptor('y', alignment, y, fm.format, fm.group); | |||||
| auto&& w_desc = make_filter_descriptor(alignment, fm); | |||||
| auto compute_type = get_compute_type(x.dtype, param.compute_mode); | |||||
| auto&& conv_desc = make_conv_descriptor(compute_type, param); | |||||
| constexpr auto operation = CudnnBackendOpTypeTrait<Opr>::OPERATION; | |||||
| auto op_graph = build_opgraph(handle, operation, x_desc, y_desc, w_desc, conv_desc); | |||||
| auto deterministic = static_data().deterministic; | |||||
| auto heur_mode = static_data().heur_mode; | |||||
| auto heurgen_method = [&deterministic, | |||||
| &heur_mode](cudnn_frontend::OperationGraph& op_graph) | |||||
| -> cudnn_frontend::EngineConfigList { | |||||
| auto heuristics = cudnn_frontend::EngineHeuristicsBuilder() | |||||
| .setOperationGraph(op_graph) | |||||
| .setHeurMode(heur_mode) | |||||
| .build(); | |||||
| auto& engine_configs = | |||||
| heuristics.getEngineConfig(heuristics.getEngineConfigCount()); | |||||
| cudnn_frontend::EngineConfigList filtered_configs; | |||||
| filter_engine_configs(engine_configs, filtered_configs, deterministic); | |||||
| return filtered_configs; | |||||
| }; | |||||
| auto fallback_method = [&deterministic, &heur_mode, | |||||
| &operation](cudnn_frontend::OperationGraph& op_graph) | |||||
| -> cudnn_frontend::EngineConfigList { | |||||
| auto fallback = cudnn_frontend::EngineFallbackListBuilder() | |||||
| .setOperationGraph(op_graph) | |||||
| .setOperation(operation) | |||||
| .build(); | |||||
| auto& fallback_list = fallback.getFallbackList(); | |||||
| cudnn_frontend::EngineConfigList filtered_configs; | |||||
| filter_engine_configs(fallback_list, filtered_configs, deterministic); | |||||
| return filtered_configs; | |||||
| }; | |||||
| std::array<cudnn_frontend::GeneratorSource const, 2> sources = { | |||||
| heurgen_method, fallback_method}; | |||||
| cudnn_frontend::EngineConfigGenerator generator(sources.size(), sources.data()); | |||||
| auto configs = generator.generate_engine_config(op_graph); | |||||
| for (auto& config : configs) { | |||||
| try { | |||||
| auto plan = cudnn_frontend::ExecutionPlanBuilder() | |||||
| .setHandle(handle) | |||||
| .setEngineConfig(config) | |||||
| .build(); | |||||
| auto workspace_size = plan.getWorkspaceSize(); | |||||
| MEGDNN_MARK_USED_VAR(workspace_size); | |||||
| MEGDNN_LOCK_GUARD(static_data().cache_mutex); | |||||
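| // if another thread cached a plan for this key in the meantime, the | |||||
| // insert is a no-op and we return the already-cached plan instead. | |||||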
| auto insert = cache.insert(std::make_pair(key, std::move(plan))); | |||||
| return &insert.first->second; | |||||
| } catch (cudnn_frontend::cudnnException&) { | |||||
| continue; | |||||
| } | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| #define INST(_Opr) \ | |||||
| template cudnn_frontend::ExecutionPlan* megdnn::cuda::get_heuristic_plan_from_opr( \ | |||||
| const _Opr* opr, const TensorLayout& x, const TensorLayout& y, \ | |||||
| const TensorLayout& w, const TensorLayout& b, const TensorLayout& z, \ | |||||
| const typename _Opr::CanonizedFilterMeta& fm); | |||||
| INST(ConvolutionForward); | |||||
| INST(ConvolutionBackwardData); | |||||
| INST(ConvolutionBackwardFilter); | |||||
| /* --------- get heuristic plan from conv_bias opr -------- */ | |||||
| template <> | |||||
| cudnn_frontend::ExecutionPlan* megdnn::cuda::get_heuristic_plan_from_opr( | |||||
| const ConvBiasForward* opr, const TensorLayout& x, const TensorLayout& y, | |||||
| const TensorLayout& w, const TensorLayout& b, const TensorLayout& z, | |||||
| const typename ConvBiasForward::CanonizedFilterMeta& fm) { | |||||
| auto&& param = opr->param(); | |||||
| TensorLayoutArray layouts{x, y, w, b, z}; | |||||
| auto key = StaticData::Key{opr->handle(), opr->get_opr_type(), | |||||
| layouts.data(), layouts.size(), | |||||
| ¶m, sizeof(param)} | |||||
| .build_key_storage(); | |||||
| auto& cache = static_data().cache; | |||||
| { | |||||
| MEGDNN_LOCK_GUARD(static_data().cache_mutex); | |||||
| auto iter = cache.find(key); | |||||
| if (iter != cache.end()) { | |||||
| return &iter->second; | |||||
| } | |||||
| } | |||||
| size_t aligned = 16; | |||||
| uint8_t alignment = std::min(opr->handle()->alignment_requirement(), aligned); | |||||
| auto&& handle = cudnn_handle(opr->handle()); | |||||
| auto&& x_desc = make_tensor_descriptor('x', alignment, x, fm.format, fm.group); | |||||
| auto&& y_desc = make_tensor_descriptor('y', alignment, y, fm.format, fm.group); | |||||
| auto&& w_desc = make_filter_descriptor(alignment, fm); | |||||
| auto&& z_desc = make_tensor_descriptor('z', alignment, y, fm.format, fm.group); | |||||
| auto&& b_desc = make_tensor_descriptor('b', alignment, b, Format::NCHW, fm.group); | |||||
| auto&& after_conv = | |||||
| make_tensor_descriptor('C', alignment, y, fm.format, fm.group, true); | |||||
| auto&& after_add = | |||||
| make_tensor_descriptor('A', alignment, y, fm.format, fm.group, true); | |||||
| auto&& after_bias = | |||||
| make_tensor_descriptor('B', alignment, y, fm.format, fm.group, true); | |||||
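| // 'C', 'A' and 'B' are virtual tensors: they only name the intermediate | |||||
| // edges of the fused graph and are never bound to device memory. | |||||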
| auto compute_type = get_compute_type(x.dtype, param.compute_mode); | |||||
| auto&& conv_desc = make_conv_descriptor(compute_type, param); | |||||
| float alpha, beta; | |||||
| std::tie(alpha, beta) = | |||||
| conv_bias::cudnn_get_conv_bias_act_scale_param(x, y, w, b, z); | |||||
| // OperationGraph has neither a public copy constructor nor a default | |||||
| // constructor, so we build it inside a lambda and return it by value. | |||||
| auto get_op_graph = [&]() { | |||||
| if (param.nonlineMode == param::ConvBias::NonlineMode::IDENTITY) { | |||||
| return build_opgraph_fused_nonactivation( | |||||
| handle, x_desc, y_desc, w_desc, b_desc, z_desc, after_add, | |||||
| after_conv, conv_desc, alpha, beta); | |||||
| } else { | |||||
| auto&& act_desc = | |||||
| make_activation_descriptor(dtype::Float32(), param.nonlineMode); | |||||
| return build_opgraph_fused( | |||||
| handle, x_desc, y_desc, w_desc, b_desc, z_desc, after_add, | |||||
| after_bias, after_conv, conv_desc, act_desc, alpha, beta); | |||||
| } | |||||
| }; | |||||
| auto op_graph = get_op_graph(); | |||||
| auto deterministic = static_data().deterministic; | |||||
| auto heur_mode = static_data().heur_mode; | |||||
| auto heurgen_method = [&deterministic, | |||||
| &heur_mode](cudnn_frontend::OperationGraph& op_graph) | |||||
| -> cudnn_frontend::EngineConfigList { | |||||
| auto heuristics = cudnn_frontend::EngineHeuristicsBuilder() | |||||
| .setOperationGraph(op_graph) | |||||
| .setHeurMode(heur_mode) | |||||
| .build(); | |||||
| auto& engine_configs = | |||||
| heuristics.getEngineConfig(heuristics.getEngineConfigCount()); | |||||
| cudnn_frontend::EngineConfigList filtered_configs; | |||||
| filter_engine_configs(engine_configs, filtered_configs, deterministic); | |||||
| return filtered_configs; | |||||
| }; | |||||
| std::array<cudnn_frontend::GeneratorSource const, 1> sources = {heurgen_method}; | |||||
| cudnn_frontend::EngineConfigGenerator generator(sources.size(), sources.data()); | |||||
| auto configs = generator.generate_engine_config(op_graph); | |||||
| for (auto& config : configs) { | |||||
| try { | |||||
| auto plan = cudnn_frontend::ExecutionPlanBuilder() | |||||
| .setHandle(handle) | |||||
| .setEngineConfig(config) | |||||
| .build(); | |||||
| auto workspace_size = plan.getWorkspaceSize(); | |||||
| MEGDNN_MARK_USED_VAR(workspace_size); | |||||
| MEGDNN_LOCK_GUARD(static_data().cache_mutex); | |||||
| auto insert = cache.insert(std::make_pair(key, std::move(plan))); | |||||
| return &insert.first->second; | |||||
| } catch (cudnn_frontend::cudnnException&) { | |||||
| continue; | |||||
| } | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| /* ------ impl for running a single conv ----- */ | |||||
| void megdnn::cuda::run_single_conv_with_plan( | |||||
| const cudnnHandle_t& handle, const cudnn_frontend::ExecutionPlan& plan, | |||||
| const TensorND& x, const TensorND& y, const TensorND& w, | |||||
| const Workspace& workspace) { | |||||
| size_t workspace_size = plan.getWorkspaceSize(); | |||||
| megdnn_assert( | |||||
| workspace.size >= workspace_size, | |||||
| "workspace does not meet the requirement of execution " | |||||
| "plan(got:%zu,expected:%zu)", | |||||
| workspace.size, workspace_size); | |||||
| void* data_ptrs[] = {x.raw_ptr(), y.raw_ptr(), w.raw_ptr()}; | |||||
| int64_t uids[] = {'x', 'y', 'w'}; | |||||
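| // each uid must match the id passed to setId() on the corresponding | |||||
| // tensor descriptor when the plan's operation graph was built. | |||||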
| auto variant_pack = cudnn_frontend::VariantPackBuilder() | |||||
| .setWorkspacePointer(workspace.raw_ptr) | |||||
| .setDataPointers(3, data_ptrs) | |||||
| .setUids(3, uids) | |||||
| .build(); | |||||
| cudnn_check(cudnnBackendExecute( | |||||
| handle, plan.get_raw_desc(), variant_pack.get_raw_desc())); | |||||
| } | |||||
| /* ------ impl for running a fused conv bias activation ----- */ | |||||
| void megdnn::cuda::run_conv_bias_act_with_plan( | |||||
| const cudnnHandle_t& handle, const cudnn_frontend::ExecutionPlan& plan, | |||||
| const TensorND& x, const TensorND& y, const TensorND& w, const TensorND& b, | |||||
| const TensorND& z, const Workspace& workspace) { | |||||
| size_t workspace_size = plan.getWorkspaceSize(); | |||||
| megdnn_assert( | |||||
| workspace.size >= workspace_size, | |||||
| "workspace does not meet the requirement of execution " | |||||
| "plan(got:%zu,expected:%zu)", | |||||
| workspace.size, workspace_size); | |||||
| void* z_ptr = z.layout.ndim == 0 ? nullptr : z.raw_ptr(); | |||||
| void* data_ptrs[] = {x.raw_ptr(), y.raw_ptr(), w.raw_ptr(), z_ptr, b.raw_ptr()}; | |||||
| int64_t uids[] = {'x', 'y', 'w', 'z', 'b'}; | |||||
| auto variant_pack = cudnn_frontend::VariantPackBuilder() | |||||
| .setWorkspacePointer(workspace.raw_ptr) | |||||
| .setDataPointers(5, data_ptrs) | |||||
| .setUids(5, uids) | |||||
| .build(); | |||||
| cudnn_check(cudnnBackendExecute( | |||||
| handle, plan.get_raw_desc(), variant_pack.get_raw_desc())); | |||||
| } | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/cudnn_wrapper_v8.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megdnn/basic_types.h" | |||||
| #include "megdnn/oprs/nn.h" | |||||
| #include "src/common/utils.h" | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-variable" | |||||
| #pragma GCC diagnostic ignored "-Wunused-function" | |||||
| #pragma GCC diagnostic ignored "-Wreorder" | |||||
| #include "cudnn_frontend.h" | |||||
| #pragma GCC diagnostic pop | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| static inline std::pair<int64_t, int64_t> get_vector_count_and_dimension( | |||||
| const param::Convolution::Format format) { | |||||
| using Format = param::Convolution::Format; | |||||
| int64_t vector_count = 1; | |||||
| int64_t vector_dimension = 1; | |||||
| switch (format) { | |||||
| case Format::NCHW: | |||||
| break; | |||||
| case Format::NHWC: | |||||
| vector_dimension = 3; | |||||
| break; | |||||
| case Format::NCHW4: | |||||
| vector_count = 4; | |||||
| break; | |||||
| case Format::NCHW32: | |||||
| vector_count = 32; | |||||
| break; | |||||
| default: | |||||
| megdnn_assert( | |||||
| false, "unsupported format (got:%u) for cudnn", | |||||
| static_cast<uint32_t>(format)); | |||||
| } | |||||
| return {vector_count, vector_dimension}; | |||||
| } | |||||
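For instance, a minimal check of the mapping (a sketch of ours; it assumes the header can be included on its own from the dnn source tree):

    #include <cassert>
    #include "src/cuda/cudnn_wrapper_v8.h"

    int main() {
        using Format = megdnn::param::Convolution::Format;
        // NCHW32 packs 32 channel values per element along dimension 1 (C);
        // NHWC is unvectorized, with channels as dimension 3.
        auto p = megdnn::cuda::get_vector_count_and_dimension(Format::NCHW32);
        assert(p.first == 32 && p.second == 1);
        p = megdnn::cuda::get_vector_count_and_dimension(Format::NHWC);
        assert(p.first == 1 && p.second == 3);
        return 0;
    }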
| template <typename Opr> | |||||
| cudnn_frontend::ExecutionPlan* get_heuristic_plan_from_opr( | |||||
| const Opr* opr, const TensorLayout& x, const TensorLayout& y, | |||||
| const TensorLayout& w, const TensorLayout& b, const TensorLayout& z, | |||||
| const typename Opr::CanonizedFilterMeta& fm); | |||||
| void run_single_conv_with_plan( | |||||
| const cudnnHandle_t& handle, const cudnn_frontend::ExecutionPlan& plan, | |||||
| const TensorND& x, const TensorND& y, const TensorND& w, | |||||
| const Workspace& workspace); | |||||
| void run_conv_bias_act_with_plan( | |||||
| const cudnnHandle_t& handle, const cudnn_frontend::ExecutionPlan& plan, | |||||
| const TensorND& x, const TensorND& y, const TensorND& w, const TensorND& b, | |||||
| const TensorND& z, const Workspace& workspace); | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -58,6 +58,11 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle) | |||||
| For example `export CUDA_CACHE_MAXSIZE=2147483647` and `export CUDA_CACHE_PATH=/data/.cuda_cache`)"); | For example `export CUDA_CACHE_MAXSIZE=2147483647` and `export CUDA_CACHE_PATH=/data/.cuda_cache`)"); | ||||
| } | } | ||||
| #endif | #endif | ||||
| size_t free, tot; | |||||
| cudaMemGetInfo(&free, &tot); | |||||
| printf("before cudnn create, free: %.2f MB, tot: %.2f MB, allocated: %.2f MB\n", | |||||
| free / 1024.0 / 1024.0, tot / 1024.0 / 1024.0, | |||||
| (tot - free) / 1024.0 / 1024.0); | |||||
| cudnn_check(cudnnCreate(&m_cudnn_handle)); | cudnn_check(cudnnCreate(&m_cudnn_handle)); | ||||
| cublas_check(cublasCreate(&m_cublas_handle)); | cublas_check(cublasCreate(&m_cublas_handle)); | ||||
| #if CUDA_VERSION >= 10010 | #if CUDA_VERSION >= 10010 | ||||
| @@ -69,6 +74,11 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle) | |||||
| cudnn_check(cudnnSetStream(m_cudnn_handle, stream())); | cudnn_check(cudnnSetStream(m_cudnn_handle, stream())); | ||||
| cublas_check(cublasSetStream(m_cublas_handle, stream())); | cublas_check(cublasSetStream(m_cublas_handle, stream())); | ||||
| #if CUDNN_VERSION >= 8004 | |||||
| // cudnn_check(cudnnOpsInferVersionCheck()); | |||||
| // cudnn_check(cudnnCnnInferVersionCheck()); | |||||
| #endif | |||||
| // Note that all cublas scalars (alpha, beta) and scalar results such as dot | // Note that all cublas scalars (alpha, beta) and scalar results such as dot | ||||
| // output resides at device side. | // output resides at device side. | ||||
| cublas_check(cublasSetPointerMode(m_cublas_handle, CUBLAS_POINTER_MODE_DEVICE)); | cublas_check(cublasSetPointerMode(m_cublas_handle, CUBLAS_POINTER_MODE_DEVICE)); | ||||
| @@ -82,6 +92,11 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle) | |||||
| cudaMemcpyHostToDevice, stream())); | cudaMemcpyHostToDevice, stream())); | ||||
| cuda_check(cudaStreamSynchronize(stream())); | cuda_check(cudaStreamSynchronize(stream())); | ||||
| cudaMemGetInfo(&free, &tot); | |||||
| printf("after cudnn create, free: %.2f MB, tot: %.2f MB, allocated: %.2f MB\n", | |||||
| free / 1024.0 / 1024.0, tot / 1024.0 / 1024.0, | |||||
| (tot - free) / 1024.0 / 1024.0); | |||||
| // check tk1 | // check tk1 | ||||
| m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0); | m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0); | ||||
| m_cusolver_handle = nullptr; | m_cusolver_handle = nullptr; | ||||
| @@ -0,0 +1,304 @@ | |||||
| /** | |||||
| * \file dnn/test/cuda/conv_bias.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "megdnn/dtype.h" | |||||
| #include "test/cuda/fixture.h" | |||||
| #include "megdnn/opr_param_defs.h" | |||||
| #include "megdnn/oprs.h" | |||||
| #include "src/cuda/handle.h" | |||||
| #include "test/common/benchmarker.h" | |||||
| #include "test/common/checker.h" | |||||
| #include "test/common/conv_bias.h" | |||||
| #include "test/common/rng.h" | |||||
| #include "test/common/tensor.h" | |||||
| #include "test/common/workspace_wrapper.h" | |||||
| #include "test/cuda/utils.h" | |||||
| using namespace megdnn; | |||||
| using namespace test; | |||||
| using namespace conv_bias; | |||||
| #if CUDNN_VERSION > 8004 | |||||
| TEST_F(CUDA, CONV_V8_FLOAT) { | |||||
| Checker<ConvBiasForward> checker(handle_cuda()); | |||||
| checker.set_before_exec_callback( | |||||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(ExecutionPolicyAlgoName{ | |||||
| ConvBiasForward::algo_name<ConvBiasForward::DefaultParam>( | |||||
| "CUDNN:ConvolutionV8", {}) | |||||
| .c_str()})); | |||||
| UniformFloatRNG rng(0.f, 1.f); | |||||
| checker.set_rng(0, &rng) | |||||
| .set_rng(1, &rng) | |||||
| .set_rng(2, &rng) | |||||
| .set_rng(3, &rng) | |||||
| .set_dtype(0, dtype::Float32()) | |||||
| .set_dtype(1, dtype::Float32()) | |||||
| .set_dtype(2, dtype::Float32()) | |||||
| .set_dtype(3, dtype::Float32()); | |||||
| param::ConvBias param; | |||||
| param.pad_h = param.pad_w = 1; | |||||
| param.stride_h = param.stride_w = 1; | |||||
| param.format = param::ConvBias::Format::NCHW; | |||||
| param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 64, 7, 7}, {64, 64, 3, 3}, {1, 64, 1, 1}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 64, 7, 7}, {64, 64, 3, 3}, {1, 64, 1, 1}, {1, 64, 7, 7}, {}}); | |||||
| // group | |||||
| param.sparse = param::ConvBias::Sparse::GROUP; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 64, 7, 7}, {8, 8, 8, 3, 3}, {1, 64, 1, 1}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 64, 7, 7}, {8, 8, 8, 3, 3}, {1, 64, 1, 1}, {1, 64, 7, 7}, {}}); | |||||
| // NHWC | |||||
| param.format = param::ConvBias::Format::NHWC; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {1, 7, 7, 64}, {}}); | |||||
| } | |||||
| TEST_F(CUDA, CONV_V8_HALF) { | |||||
| Checker<ConvBiasForward> checker(handle_cuda()); | |||||
| checker.set_before_exec_callback( | |||||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(ExecutionPolicyAlgoName{ | |||||
| ConvBiasForward::algo_name<ConvBiasForward::DefaultParam>( | |||||
| "CUDNN:ConvolutionV8", {}) | |||||
| .c_str()})); | |||||
| UniformFloatRNG rng(0.f, 1.f); | |||||
| checker.set_rng(0, &rng) | |||||
| .set_rng(1, &rng) | |||||
| .set_rng(2, &rng) | |||||
| .set_rng(3, &rng) | |||||
| .set_dtype(0, dtype::Float16()) | |||||
| .set_dtype(1, dtype::Float16()) | |||||
| .set_dtype(2, dtype::Float16()) | |||||
| .set_dtype(3, dtype::Float16()) | |||||
| .set_dtype(4, dtype::Float16()) | |||||
| .set_epsilon(5e-2); | |||||
| param::ConvBias param; | |||||
| param.pad_h = param.pad_w = 1; | |||||
| param.stride_h = param.stride_w = 1; | |||||
| param.format = param::ConvBias::Format::NCHW; | |||||
| param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||||
| param.compute_mode = param::ConvBias::ComputeMode::FLOAT32; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 64, 7, 7}, {64, 64, 3, 3}, {1, 64, 1, 1}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 64, 7, 7}, {64, 64, 3, 3}, {1, 64, 1, 1}, {1, 64, 7, 7}, {}}); | |||||
| // group | |||||
| param.sparse = param::ConvBias::Sparse::GROUP; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 64, 7, 7}, {8, 8, 8, 3, 3}, {1, 64, 1, 1}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 64, 7, 7}, {8, 8, 8, 3, 3}, {1, 64, 1, 1}, {1, 64, 7, 7}, {}}); | |||||
| // NHWC | |||||
| param.format = param::ConvBias::Format::NHWC; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {1, 7, 7, 64}, {}}); | |||||
| } | |||||
| TEST_F(CUDA, CONV_BIAS_V8_FLOAT) { | |||||
| Checker<ConvBiasForward> checker(handle_cuda()); | |||||
| checker.set_before_exec_callback( | |||||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(ExecutionPolicyAlgoName{ | |||||
| ConvBiasForward::algo_name<ConvBiasForward::DefaultParam>( | |||||
| "CUDNN:ConvBiasActivationV8", {}) | |||||
| .c_str()})); | |||||
| UniformFloatRNG rng(0.f, 1.f); | |||||
| UniformFloatRNG crng(0.f, 0.f); | |||||
| checker.set_rng(0, &rng) | |||||
| .set_rng(1, &rng) | |||||
| .set_rng(2, &rng) | |||||
| .set_rng(3, &rng) | |||||
| .set_dtype(0, dtype::Float32()) | |||||
| .set_dtype(1, dtype::Float32()) | |||||
| .set_dtype(2, dtype::Float32()) | |||||
| .set_dtype(3, dtype::Float32()); | |||||
| param::ConvBias param; | |||||
| param.pad_h = param.pad_w = 1; | |||||
| param.stride_h = param.stride_w = 1; | |||||
| param.format = param::ConvBias::Format::NCHW; | |||||
| param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 64, 7, 7}, {64, 64, 3, 3}, {1, 64, 1, 1}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 64, 7, 7}, {64, 64, 3, 3}, {1, 64, 1, 1}, {1, 64, 7, 7}, {}}); | |||||
| // group | |||||
| param.sparse = param::ConvBias::Sparse::GROUP; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 64, 7, 7}, {8, 8, 8, 3, 3}, {1, 64, 1, 1}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 64, 7, 7}, {8, 8, 8, 3, 3}, {1, 64, 1, 1}, {1, 64, 7, 7}, {}}); | |||||
| // NHWC | |||||
| param.format = param::ConvBias::Format::NHWC; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {1, 7, 7, 64}, {}}); | |||||
| } | |||||
| TEST_F(CUDA, CONV_BIAS_V8_HALF) { | |||||
| Checker<ConvBiasForward> checker(handle_cuda()); | |||||
| checker.set_before_exec_callback( | |||||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(ExecutionPolicyAlgoName{ | |||||
| ConvBiasForward::algo_name<ConvBiasForward::DefaultParam>( | |||||
| "CUDNN:ConvBiasActivationV8", {}) | |||||
| .c_str()})); | |||||
| UniformFloatRNG rng(0.f, 1.f); | |||||
| checker.set_rng(0, &rng) | |||||
| .set_rng(1, &rng) | |||||
| .set_rng(2, &rng) | |||||
| .set_rng(3, &rng) | |||||
| .set_dtype(0, dtype::Float16()) | |||||
| .set_dtype(1, dtype::Float16()) | |||||
| .set_dtype(2, dtype::Float16()) | |||||
| .set_dtype(3, dtype::Float16()) | |||||
| .set_dtype(4, dtype::Float16()) | |||||
| .set_epsilon(5e-2); | |||||
| param::ConvBias param; | |||||
| param.pad_h = param.pad_w = 1; | |||||
| param.stride_h = param.stride_w = 1; | |||||
| param.format = param::ConvBias::Format::NCHW; | |||||
| param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||||
| param.compute_mode = param::ConvBias::ComputeMode::FLOAT32; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 64, 7, 7}, {64, 64, 3, 3}, {1, 64, 1, 1}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 64, 7, 7}, {64, 64, 3, 3}, {1, 64, 1, 1}, {1, 64, 7, 7}, {}}); | |||||
| // group | |||||
| param.sparse = param::ConvBias::Sparse::GROUP; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 64, 7, 7}, {8, 8, 8, 3, 3}, {1, 64, 1, 1}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 64, 7, 7}, {8, 8, 8, 3, 3}, {1, 64, 1, 1}, {1, 64, 7, 7}, {}}); | |||||
| // NHWC | |||||
| param.format = param::ConvBias::Format::NHWC; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {1, 7, 7, 64}, {}}); | |||||
| } | |||||
| TEST_F(CUDA, CONV_BIAS_V8_DP4A) { | |||||
| Checker<ConvBiasForward> checker(handle_cuda()); | |||||
| checker.set_before_exec_callback( | |||||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(ExecutionPolicyAlgoName{ | |||||
| ConvBiasForward::algo_name<ConvBiasForward::DefaultParam>( | |||||
| "CUDNN:ConvBiasActivationV8", {}) | |||||
| .c_str()})); | |||||
| UniformIntRNG rng{-3, 3}; | |||||
| UniformIntRNG bias_rng{-50, 50}; | |||||
| checker.set_rng(0, &rng) | |||||
| .set_rng(1, &rng) | |||||
| .set_rng(2, &bias_rng) | |||||
| .set_rng(3, &rng) | |||||
| .set_dtype(0, dtype::QuantizedS8{1.2f}) | |||||
| .set_dtype(1, dtype::QuantizedS8{1.3f}) | |||||
| .set_dtype(2, dtype::QuantizedS32{1.2f * 1.3f}) | |||||
| .set_dtype(3, dtype::QuantizedS8{1.1f}) | |||||
| .set_dtype(4, dtype::QuantizedS8{1.0f}) | |||||
| .set_epsilon(1 + 1e-3); | |||||
| param::ConvBias param; | |||||
| param.pad_h = param.pad_w = 1; | |||||
| param.stride_h = param.stride_w = 1; | |||||
| param.format = param::ConvBias::Format::NCHW4; | |||||
| param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 16, 7, 7, 4}, {64, 16, 3, 3, 4}, {1, 16, 1, 1, 4}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 16, 7, 7, 4}, | |||||
| {64, 16, 3, 3, 4}, | |||||
| {1, 16, 1, 1, 4}, | |||||
| {1, 16, 7, 7, 4}, | |||||
| {}}); | |||||
| param.nonlineMode = param::ConvBias::NonlineMode::IDENTITY; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 16, 7, 7, 4}, {64, 16, 3, 3, 4}, {1, 16, 1, 1, 4}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 16, 7, 7, 4}, | |||||
| {64, 16, 3, 3, 4}, | |||||
| {1, 16, 1, 1, 4}, | |||||
| {1, 16, 7, 7, 4}, | |||||
| {}}); | |||||
| param.format = param::ConvBias::Format::NHWC; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 7, 7, 64}, {64, 3, 3, 64}, {1, 1, 1, 64}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 7, 7, 64}, {64, 3, 3, 64}, {1, 1, 1, 64}, {1, 7, 7, 64}, {}}); | |||||
| param.sparse = param::ConvBias::Sparse::GROUP; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 7, 7, 64}, {8, 8, 3, 3, 8}, {1, 1, 1, 64}, {1, 7, 7, 64}, {}}); | |||||
| } | |||||
| TEST_F(CUDA, CONV_BIAS_V8_IMMA) { | |||||
| Checker<ConvBiasForward> checker(handle_cuda()); | |||||
| checker.set_before_exec_callback( | |||||
| conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(ExecutionPolicyAlgoName{ | |||||
| ConvBiasForward::algo_name<ConvBiasForward::DefaultParam>( | |||||
| "CUDNN:ConvBiasActivationV8", {}) | |||||
| .c_str()})); | |||||
| UniformIntRNG rng{-3, 3}; | |||||
| UniformIntRNG bias_rng{-50, 50}; | |||||
| checker.set_rng(0, &rng) | |||||
| .set_rng(1, &rng) | |||||
| .set_rng(2, &bias_rng) | |||||
| .set_rng(3, &rng) | |||||
| .set_dtype(0, dtype::QuantizedS8{1.2f}) | |||||
| .set_dtype(1, dtype::QuantizedS8{1.3f}) | |||||
| .set_dtype(2, dtype::QuantizedS32{1.2f * 1.3f}) | |||||
| .set_dtype(3, dtype::QuantizedS8{1.1f}) | |||||
| .set_dtype(4, dtype::QuantizedS8{1.0f}) | |||||
| .set_epsilon(1 + 1e-3); | |||||
| param::ConvBias param; | |||||
| param.pad_h = param.pad_w = 1; | |||||
| param.stride_h = param.stride_w = 1; | |||||
| param.format = param::ConvBias::Format::NCHW32; | |||||
| param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||||
| checker.set_param(param).execs( | |||||
| {{1, 2, 7, 7, 32}, {64, 2, 3, 3, 32}, {1, 2, 1, 1, 32}, {}, {}}); | |||||
| checker.set_param(param).execs( | |||||
| {{1, 2, 7, 7, 32}, | |||||
| {64, 2, 3, 3, 32}, | |||||
| {1, 2, 1, 1, 32}, | |||||
| {1, 2, 7, 7, 32}, | |||||
| {}}); | |||||
| param.nonlineMode = param::ConvBias::NonlineMode::RELU; | |||||
| param.stride_h = param.stride_w = 1; | |||||
| param.pad_h = param.pad_w = 0; | |||||
| checker.set_param(param).execs( | |||||
| {{2, 8, 12, 12, 32}, {512, 8, 1, 1, 32}, {1, 16, 1, 1, 32}, {}, {}}); | |||||
| } | |||||
| #endif | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -94,6 +94,7 @@ function git_submodule_update() { | |||||
| git submodule sync | git submodule sync | ||||
| git submodule update -f --init midout | git submodule update -f --init midout | ||||
| git submodule update -f --init flatbuffers | git submodule update -f --init flatbuffers | ||||
| git submodule update -f --init cudnn-frontend | |||||
| git submodule update -f --init Json | git submodule update -f --init Json | ||||
| git submodule update -f --init gflags | git submodule update -f --init gflags | ||||
| git submodule update -f --init cpuinfo | git submodule update -f --init cpuinfo | ||||