This reverts commit 188c62cdd6 (tags/v1.7.0). GitOrigin-RevId: 92a82b8cd9
@@ -12,28 +12,32 @@
#pragma once
#include <cstring>
#include <unordered_map>
#include <memory>
#include <cstring>
#include <tuple>
#include <unordered_map>
#include "megdnn/thin/function.h"
namespace megdnn {
template <typename... TArgs>
class FunctionCache {
template <typename TSignature>
class FunctionCache;
template <typename TRet, typename... TArgs>
class FunctionCache<TRet(TArgs...)> {
public:
using key_t = std::string;
using value_t = std::string;
using value_t = TRet;
using key_mapper_t = thin_function<key_t(TArgs...)>;
using value_mapper_t = thin_function<value_t(TArgs...)>;
using storage_t = std::unordered_map<key_t, value_t>;
public:
storage_t storage;
key_mapper_t key_mapper;
value_mapper_t value_mapper;
value_t operator()(TArgs... args) {
public:
TRet operator()(TArgs... args) {
key_t key = key_mapper(args...);
if (storage.count(key) == 0) {
storage[key] = value_mapper(std::forward<TArgs>(args)...);
@@ -42,28 +46,28 @@ public:
}
};
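FunctionCache memoizes an arbitrary callable: key_mapper turns the call arguments into a string key, value_mapper computes the result on a miss, and storage keeps the results keyed by that string. A minimal usage sketch, assuming the signature-templated form above (FunctionCache<TRet(TArgs...)>) from this header; the lambdas are illustrative, not part of the diff:

// Cache an expensive (int, int) -> std::string computation.
FunctionCache<std::string(int, int)> cache;
cache.key_mapper = [](int a, int b) {
    // build the lookup key from the arguments
    return std::to_string(a) + "," + std::to_string(b);
};
cache.value_mapper = [](int a, int b) {
    // stand-in for the expensive call being cached
    return std::to_string(a * b);
};
std::string first = cache(6, 7);   // miss: value_mapper runs, result stored under "6,7"
std::string again = cache(6, 7);   // hit: served from storage, value_mapper is not called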
// FIFO
class StringSerializer {
private:
std::string m_buffer;
size_t m_cursor = 0;
public:
template <typename T>
T read_plain() {
static_assert(std::is_trivially_copyable<T>::value, "invalid type");
T ret;
memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
T result;
std::memcpy(&result, m_buffer.data() + m_cursor, sizeof(T));
m_cursor += sizeof(T);
return ret;
return result;
}
template <typename T>
void write_plain(T value) {
static_assert(std::is_trivially_copyable<T>::value,
"type should be trivially copyable");
m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T));
m_buffer.resize(m_buffer.size() + sizeof(T));
std::memcpy(const_cast<char*>(m_buffer.data()) + (m_buffer.size() - sizeof(T)), &value, sizeof(T));
}
std::string take() {
std::string result;
m_buffer.erase(0, m_cursor);
return std::move(m_buffer);
}
void set(std::string new_buf) {
@@ -72,20 +76,20 @@ public:
}
};
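StringSerializer is a FIFO byte buffer: write_plain appends the raw bytes of a trivially copyable value and read_plain pops them back in order. A round-trip sketch, assuming set() also resets the read cursor (its body falls outside the hunk shown above):

StringSerializer ser;
ser.write_plain<int>(3);          // append the raw bytes of an int
ser.write_plain<float>(2.5f);     // then of a float
std::string blob = ser.take();    // take the accumulated bytes (used as cache key/value below)

StringSerializer de;
de.set(blob);                     // load a previously produced buffer
int n = de.read_plain<int>();     // values must be read back in the same order and types
float f = de.read_plain<float>();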
struct Empty {};
template <typename... TParams>
class ParamBundle {
private:
template <std::size_t N, std::size_t... Seq>
static std::index_sequence<N + Seq...> add_all(
std::index_sequence<Seq...>) {
template<std::size_t N, std::size_t... Seq>
static std::index_sequence<N + Seq ...> add_all(std::index_sequence<Seq...>){
return {};
}
template <std::size_t Min, std::size_t Max>
using make_index_range =
decltype(add_all<Min>(std::make_index_sequence<Max - Min>()));
template<std::size_t Min, std::size_t Max>
using make_index_range = decltype(add_all<Min>(std::make_index_sequence<Max-Min>()));
using storage_t = std::tuple<typename std::remove_reference_t<TParams>...>;
storage_t m_storage;
@@ -95,31 +99,21 @@ private:
return functor(std::get<Indices>(m_storage).value...);
}
template <size_t Index, size_t... Indices, typename TPrev>
auto serialize_helper(StringSerializer& ser, TPrev&& prev,
std::index_sequence<Index, Indices...>) {
return serialize_helper(ser,
std::get<Index>(m_storage).serialize(ser, prev),
std::index_sequence<Indices...>());
auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<Index, Indices...>) {
return serialize_helper(ser, std::get<Index>(m_storage).serialize(ser, prev), std::index_sequence<Indices...>());
}
template <typename TPrev>
auto serialize_helper(StringSerializer& ser, TPrev&& prev,
std::index_sequence<>) {}
auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {}
template <size_t Index, size_t... Indices, typename TPrev>
auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
std::index_sequence<Index, Indices...>) {
return deserialize_helper(
ser, std::get<Index>(m_storage).deserialize(ser, prev),
std::index_sequence<Indices...>());
auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<Index, Indices...>) {
return deserialize_helper(ser, std::get<Index>(m_storage).deserialize(ser, prev), std::index_sequence<Indices...>());
}
template <typename TPrev>
auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
std::index_sequence<>) {}
auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {}
template <size_t Index, size_t... Indices, typename TArg, typename... TArgs>
void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg,
TArgs&&... args) {
void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg, TArgs&&... args) {
std::get<Index>(m_storage).value = arg;
set_values_helper(std::index_sequence<Indices...>(),
std::forward<TArgs>(args)...);
set_values_helper(std::index_sequence<Indices...>(), std::forward<TArgs>(args)...);
}
template <size_t... Indices>
void set_values_helper(std::index_sequence<Indices...>) {
@@ -129,33 +123,27 @@ private:
public:
template <typename TFunctor>
auto call_by(TFunctor&& functor) {
return call_helper(std::forward<TFunctor>(functor),
std::make_index_sequence<sizeof...(TParams)>());
return call_helper(std::forward<TFunctor>(functor), std::make_index_sequence<sizeof...(TParams)>());
}
template <size_t NBegin, size_t NEnd>
void serialize_params(StringSerializer& ser) {
static_assert(NEnd >= NBegin, "invalid range");
serialize_helper(
ser, Empty{},
add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()));
serialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
}
template <size_t NBegin, size_t NEnd>
void deserialize_params(StringSerializer& ser) {
static_assert(NEnd >= NBegin, "invalid range");
deserialize_helper(
ser, Empty{},
add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()));
deserialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
}
template <size_t NBegin, size_t NEnd, typename... TArgs>
void set_values(TArgs&&... args) {
set_values_helper(
add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()),
std::forward<TArgs>(args)...);
set_values_helper(make_index_range<NBegin, NEnd>(), std::forward<TArgs>(args)...);
}
};
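ParamBundle stores one wrapper object per parameter: set_values fills their value members, call_by forwards those values to a callable, and serialize_params/deserialize_params walk an index range through each wrapper's serialize/deserialize. A small sketch, assuming the PlainParam wrapper defined later in this diff:

ParamBundle<PlainParam<int>, PlainParam<int>> bundle;
bundle.set_values<0, 2>(6, 7);                     // fill both wrapped values
int product = bundle.call_by([](int a, int b) {    // values forwarded as plain ints
    return a * b;
});
StringSerializer ser;
bundle.serialize_params<0, 2>(ser);                // serialize the params in range [0, 2)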
template <typename T>
class Param {
class RetParam {
public:
T value;
Empty serialize(StringSerializer& ser, Empty) {
@@ -168,68 +156,45 @@ public:
}
};
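Each parameter wrapper follows the same protocol: a public value member plus serialize/deserialize members that take the previous step's result and return something for the next step (Empty in the simple cases), which is what lets ParamBundle chain them left to right. A hypothetical wrapper written against that protocol, mirroring PlainParam below:

// Illustrative only; not part of this diff.
struct MyFlagParam {
    bool value;
    Empty serialize(StringSerializer& ser, Empty) {
        ser.write_plain(value);           // contribute this param's bytes to the buffer
        return Empty{};
    }
    Empty deserialize(StringSerializer& ser, Empty) {
        value = ser.read_plain<bool>();   // read them back from the same position
        return Empty{};
    }
};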
template <typename TRet = Param<Empty>, typename TInputs = std::tuple<>,
typename TOutputs = std::tuple<>>
template <typename TRet=RetParam<Empty>, typename TInputs=std::tuple<>, typename TOutputs=std::tuple<>>
class FunctionCacheBuilder {
private:
static auto declargs()
-> decltype(std::tuple_cat(std::declval<TInputs>(),
std::declval<TOutputs>())) {
return {};
}
static auto declargs() -> decltype(std::tuple_cat(std::declval<TInputs>(), std::declval<TOutputs>())) { return {}; }
template <size_t... Indices>
static auto declfunction_helper(std::index_sequence<Indices...>)
-> thin_function<decltype(std::declval<TRet>().value)(
decltype(std::get<Indices>(declargs()).value)...)> {
return {};
}
static auto declfunction_helper(std::index_sequence<Indices...>) -> thin_function<decltype(std::declval<TRet>().value)(decltype(std::get<Indices>(declargs()).value)...)> { return {}; }
static auto declfunction() {
return declfunction_helper(
std::make_index_sequence<std::tuple_size<TInputs>::value +
std::tuple_size<TOutputs>::value>());
return declfunction_helper(std::make_index_sequence<std::tuple_size<TInputs>::value + std::tuple_size<TOutputs>::value>());
}
template <size_t... Indices>
static auto declbundle_helper(std::index_sequence<Indices...>)
-> ParamBundle<decltype(std::get<Indices>(declargs()))...> {
return {};
}
static auto declbundle_helper(std::index_sequence<Indices...>) -> ParamBundle<decltype(std::get<Indices>(declargs()))...> { return {}; }
static auto declbundle() {
return declbundle_helper(
std::make_index_sequence<std::tuple_size<TInputs>::value +
std::tuple_size<TOutputs>::value>());
return declbundle_helper(std::make_index_sequence<std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value>());
}
using function_t = decltype(declfunction());
using bundle_t = decltype(declbundle());
public:
template <typename TNewRet>
auto ret() {
static_assert(std::is_same<TRet, Param<Empty>>::value,
"return value redefinition");
static_assert(std::is_same<TRet, RetParam<Empty>>::value, "return value redefinition");
return FunctionCacheBuilder<TNewRet, TInputs, TOutputs>{};
}
template <typename TNewInput>
auto input() {
using TNewInputs = decltype(
std::tuple_cat(std::declval<TInputs>(),
std::make_tuple(std::declval<TNewInput>())));
using TNewInputs = decltype(std::tuple_cat(std::declval<TInputs>(), std::make_tuple(std::declval<TNewInput>())));
return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{};
}
template <typename TNewOutput>
auto output() {
using TNewOutputs = decltype(
std::tuple_cat(std::declval<TOutputs>(),
std::make_tuple(std::declval<TNewOutput>())));
using TNewOutputs = decltype(std::tuple_cat(std::declval<TOutputs>(), std::make_tuple(std::declval<TNewOutput>())));
return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{};
}
template <typename TFunctor>
function_t build(TFunctor func) {
FunctionCache<bundle_t> cache;
FunctionCache<std::string(bundle_t)> cache;
cache.key_mapper = [](bundle_t bundle) {
StringSerializer ser;
bundle.template serialize_params<0,
std::tuple_size<TInputs>::value>(
ser);
bundle.template serialize_params<0, std::tuple_size<TInputs>::value>(ser);
return ser.take();
};
cache.value_mapper = [=](bundle_t bundle) {
@@ -237,33 +202,42 @@ public:
TRet ret;
ret.value = bundle.call_by(func);
ret.serialize(ser, Empty{});
bundle.template serialize_params<
std::tuple_size<TInputs>::value,
std::tuple_size<TInputs>::value +
std::tuple_size<TOutputs>::value>(ser);
bundle.template serialize_params<std::tuple_size<TInputs>::value, std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value>(ser);
return ser.take();
};
return [=](auto&&... args) mutable {
bundle_t bundle;
TRet ret;
StringSerializer ser;
static_assert(
sizeof...(args) == std::tuple_size<TInputs>::value +
std::tuple_size<TOutputs>::value,
"args count mismatch");
bundle.template set_values<0, sizeof...(args)>(
std::forward<decltype(args)>(args)...);
static_assert(sizeof...(args) == std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value,
"arg count mismatch");
bundle.template set_values<0, sizeof...(args)>(std::forward<decltype(args)>(args)...);
ser.set(cache(bundle));
ret.deserialize(ser, Empty{});
constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
bundle.template deserialize_params<n_inputs, n_inputs + n_outputs>(
ser);
bundle.template deserialize_params<n_inputs, n_inputs+n_outputs>(ser);
return ret.value;
};
}
};
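FunctionCacheBuilder declares the cached function's interface piece by piece: each input<> feeds the cache key, each output<> (typically an out-pointer) is replayed from the cached value, ret<> is the stored return value, and build() yields a callable with the same signature as the wrapped function. A toy sketch under those assumptions; query_size is hypothetical, and RefParam<T> is assumed to wrap a T* out-parameter as it does in the cuDNN wrappers further down:

// Hypothetical stand-in for an expensive query that reports through an out-parameter.
int query_size(int handle, int algo, size_t* out) {
    *out = static_cast<size_t>(handle) * algo;
    return 0;   // status code
}

auto cached_query =
        FunctionCacheBuilder<>()
                .input<PlainParam<int>>()     // keyed by value
                .input<PlainParam<int>>()
                .output<RefParam<size_t>>()   // replayed through the pointer on a hit
                .ret<PlainParam<int>>()       // status code stored alongside the value
                .build(&query_size);

size_t size = 0;
int status = cached_query(7, 1, &size);   // first call runs query_size and fills the cache
int replay = cached_query(7, 1, &size);   // same key: size and status come from the cache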
template <typename T>
class PlainParam {
public:
T value;
Empty serialize(StringSerializer& ser, Empty) {
ser.write_plain(value);
return Empty{};
}
Empty deserialize(StringSerializer& ser, Empty) {
value = ser.read_plain<T>();
return Empty{};
}
};
template <typename T>
class RefParam {
public:
@@ -278,6 +252,7 @@ public:
}
};
template <typename T>
class RefArraySizeParam {
public:
@@ -291,6 +266,7 @@ public:
}
};
template <typename TSize, typename TItem>
class ArrayParam {
public:
@@ -309,4 +285,4 @@ public:
}
};
} // namespace megdnn
}
@@ -16,109 +16,105 @@
#include "src/cuda/cudnn_wrapper.h"
namespace megdnn {
class CudnnConvDescParam {
public:
cudnnConvolutionDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) {
constexpr int nbDims = MEGDNN_MAX_NDIM;
int padA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
int dilationA[MEGDNN_MAX_NDIM];
cudnnConvolutionMode_t mode;
cudnnDataType_t computeType;
cudnnGetConvolutionNdDescriptor(value, nbDims, &nbDims, padA, strideA,
dilationA, &mode, &computeType);
ser.write_plain(nbDims);
for (int i = 0; i < nbDims; ++i) {
ser.write_plain(padA[i]);
ser.write_plain(strideA[i]);
ser.write_plain(dilationA[i]);
class CudnnConvDescParam {
public:
cudnnConvolutionDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) {
int ndim = MEGDNN_MAX_NDIM;
int padA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
int dilationA[MEGDNN_MAX_NDIM];
cudnnConvolutionMode_t mode;
cudnnDataType_t computeType;
cudnnGetConvolutionNdDescriptor(value, MEGDNN_MAX_NDIM, &ndim, padA, strideA, dilationA, &mode, &computeType);
ser.write_plain(ndim);
for (int i = 0; i < ndim; ++i) {
ser.write_plain(padA[i]);
ser.write_plain(strideA[i]);
ser.write_plain(dilationA[i]);
}
ser.write_plain(mode);
ser.write_plain(computeType);
return Empty{};
}
ser.write_plain(mode);
ser.write_plain(computeType);
return Empty{};
}
Empty deserialize(StringSerializer& ser, Empty) {
int ndim = ser.read_plain<int>();
int padA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
int dilationA[MEGDNN_MAX_NDIM];
for (int i = 0; i < ndim; ++i) {
padA[i] = ser.read_plain<int>();
strideA[i] = ser.read_plain<int>();
dilationA[i] = ser.read_plain<int>();
Empty deserialize(StringSerializer& ser, Empty) {
int ndim = ser.read_plain<int>();
int padA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
int dilationA[MEGDNN_MAX_NDIM];
for (int i = 0; i < ndim; ++i) {
padA[i] = ser.read_plain<int>();
strideA[i] = ser.read_plain<int>();
dilationA[i] = ser.read_plain<int>();
}
cudnnConvolutionMode_t mode = ser.read_plain<cudnnConvolutionMode_t>();
cudnnDataType_t computeType = ser.read_plain<cudnnDataType_t>();
cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA, mode, computeType);
return Empty{};
}
cudnnConvolutionMode_t mode = ser.read_plain<cudnnConvolutionMode_t>();
cudnnDataType_t computeType = ser.read_plain<cudnnDataType_t>();
cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA,
mode, computeType);
return Empty{};
}
};
class CudnnTensorDescParam {
public:
cudnnTensorDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) {
constexpr int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
int dimA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA,
strideA);
ser.write_plain(nbDims);
for (int i = 0; i < nbDims; ++i) {
ser.write_plain(dimA[i]);
ser.write_plain(strideA[i]);
};
class CudnnTensorDescParam {
public:
cudnnTensorDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
int dimA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA, strideA);
ser.write_plain(nbDims);
for (int i = 0; i < nbDims; ++i) {
ser.write_plain(dimA[i]);
ser.write_plain(strideA[i]);
}
ser.write_plain(dataType);
return Empty{};
}
ser.write_plain(dataType);
return Empty{};
}
Empty deserialize(StringSerializer& ser, Empty) {
constexpr int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
int dimA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
nbDims = ser.read_plain<int>();
for (int i = 0; i < nbDims; ++i) {
dimA[i] = ser.read_plain<int>();
strideA[i] = ser.read_plain<int>();
Empty deserialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
int dimA[MEGDNN_MAX_NDIM];
int strideA[MEGDNN_MAX_NDIM];
nbDims = ser.read_plain<int>();
for (int i = 0; i < nbDims; ++i) {
dimA[i] = ser.read_plain<int>();
strideA[i] = ser.read_plain<int>();
}
dataType = ser.read_plain<cudnnDataType_t>();
cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA);
return Empty{};
}
dataType = ser.read_plain<cudnnDataType_t>();
cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA);
return Empty{};
}
};
class CudnnFilterDescParam {
public:
cudnnFilterDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) {
constexpr int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
cudnnTensorFormat_t format;
int filterDimA[MEGDNN_MAX_NDIM];
cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims,
filterDimA);
ser.write_plain(nbDims);
for (int i = 0; i < nbDims; ++i) {
ser.write_plain(filterDimA[i]);
};
class CudnnFilterDescParam {
public:
cudnnFilterDescriptor_t value;
Empty serialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
cudnnTensorFormat_t format;
int filterDimA[MEGDNN_MAX_NDIM];
cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims, filterDimA);
ser.write_plain(nbDims);
for (int i = 0; i < nbDims; ++i) {
ser.write_plain(filterDimA[i]);
}
ser.write_plain(dataType);
ser.write_plain(format);
return Empty{};
}
ser.write_plain(dataType);
ser.write_plain(format);
return Empty{};
}
Empty deserialize(StringSerializer& ser, Empty) {
constexpr int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
cudnnTensorFormat_t format;
int filterDimA[MEGDNN_MAX_NDIM];
nbDims = ser.read_plain<int>();
for (int i = 0; i < nbDims; ++i) {
filterDimA[i] = ser.read_plain<int>();
Empty deserialize(StringSerializer& ser, Empty) {
int nbDims = MEGDNN_MAX_NDIM;
cudnnDataType_t dataType;
cudnnTensorFormat_t format;
int filterDimA[MEGDNN_MAX_NDIM];
nbDims = ser.read_plain<int>();
for (int i = 0; i < nbDims; ++i) {
filterDimA[i] = ser.read_plain<int>();
}
dataType = ser.read_plain<cudnnDataType_t>();
format = ser.read_plain<cudnnTensorFormat_t>();
cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA);
return Empty{};
}
dataType = ser.read_plain<cudnnDataType_t>();
format = ser.read_plain<cudnnTensorFormat_t>();
cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA);
return Empty{};
}
};
} // namespace megdnn
};
}
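These wrappers let cuDNN descriptors take part in the cache key and value by packing their Nd description through StringSerializer. A round-trip sketch for CudnnTensorDescParam, assuming a valid cuDNN context and omitting error checks:

cudnnTensorDescriptor_t src, dst;
cudnnCreateTensorDescriptor(&src);
cudnnCreateTensorDescriptor(&dst);
int dims[4] = {1, 8, 32, 32};
int strides[4] = {8 * 32 * 32, 32 * 32, 32, 1};
cudnnSetTensorNdDescriptor(src, CUDNN_DATA_FLOAT, 4, dims, strides);

StringSerializer ser;
CudnnTensorDescParam in;
in.value = src;
in.serialize(ser, Empty{});     // pack ndim, dims, strides and dtype into the buffer

StringSerializer de;
de.set(ser.take());
CudnnTensorDescParam out;
out.value = dst;
out.deserialize(de, Empty{});   // dst now describes the same shape, strides and dtype as src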
@@ -56,8 +56,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(
conv_args.init_conv_desc(D);
size_t workspace_size;
auto& cudnn = conv_args.handle->cudnn();
auto status = cudnn.GetConvolutionForwardWorkspaceSize(
auto status = cudnnGetConvolutionForwardWorkspaceSize(
conv_args.handle->cudnn_handle(), D.src_desc.desc,
D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc,
m_cudnn_enum, &workspace_size);
@@ -83,8 +82,7 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_bundle(
conv_args.init_conv_desc(D);
size_t conv_workspace_size;
auto& cudnn = conv_args.handle->cudnn();
auto status = cudnn.GetConvolutionForwardWorkspaceSize(
auto status = cudnnGetConvolutionForwardWorkspaceSize(
conv_args.handle->cudnn_handle(), D.src_desc.desc,
D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc,
m_cudnn_enum, &conv_workspace_size);
@@ -149,8 +149,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
megdnn_throw("unsupported NonlineMode");
}
size_t workspace_size;
auto& cudnn = args.handle->cudnn();
auto status = cudnn.GetConvolutionForwardWorkspaceSize(
auto status = cudnnGetConvolutionForwardWorkspaceSize(
args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum,
&workspace_size);
@@ -163,8 +162,7 @@ size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::get_workspace_in_bytes(
args.init_conv_bias_desc(D);
size_t workspace_size;
auto& cudnn = args.handle->cudnn();
auto status = cudnn.GetConvolutionForwardWorkspaceSize(
auto status = cudnnGetConvolutionForwardWorkspaceSize(
args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum,
&workspace_size);
@@ -95,13 +95,12 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
CUDNNForwardDescs desc;
conv_args.init_conv_desc(desc);
#if CUDNN_MAJOR >= 7
auto& cudnn = static_cast<HandleImpl*>(this->handle())->cudnn();
int max_count = 0;
cudnn_check(cudnn.GetConvolutionForwardAlgorithmMaxCount(cudnn_handle,
cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle,
&max_count));
SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(max_count);
int ret_count = 0;
cudnn_check(cudnn.GetConvolutionForwardAlgorithm_v7(
cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7(
cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc,
desc.conv_desc.conv_desc, desc.dst_desc.desc, max_count,
&ret_count, algo_perf.data()));
@@ -42,10 +42,9 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available(
if (!conv_bias::is_cudnn_supported(bias_args))
return false;
auto& cudnn = args.handle->cudnn();
args.init_desc(D);
size_t workspace_size;
auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
args.handle->cudnn_handle(),
D.filter_desc.desc,
D.diff_desc.desc,
@@ -58,11 +57,10 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available(
size_t ConvolutionBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes(
const SizeArgs &args) const {
auto& cudnn = args.handle->cudnn();
CUDNNBwdDataDescs D;
args.init_desc(D);
size_t workspace_size;
auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
args.handle->cudnn_handle(),
D.filter_desc.desc,
D.diff_desc.desc,
@@ -29,7 +29,6 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
return false;
}
}
auto& cudnn = args.handle->cudnn();
CUDNNBwdFilterDescs D;
TensorLayout bias_layout, z_layout;
@@ -44,7 +43,7 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
args.init_desc(D);
size_t workspace_size;
auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
args.handle->cudnn_handle(),
D.src_desc.desc,
D.diff_desc.desc,
@@ -57,11 +56,10 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
size_t ConvolutionBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes(
const SizeArgs &args) const {
auto& cudnn = args.handle->cudnn();
CUDNNBwdFilterDescs D;
args.init_desc(D);
size_t workspace_size;
auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
args.handle->cudnn_handle(),
D.src_desc.desc,
D.diff_desc.desc,
@@ -144,13 +144,12 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
#if CUDNN_MAJOR >= 7
MEGDNN_MARK_USED_VAR(negative_attr);
auto& cudnn = args.handle->cudnn();
int max_count = 0;
cudnn_check(cudnn.GetConvolutionBackwardDataAlgorithmMaxCount(
cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(
cudnn_handle, &max_count));
SmallVector<cudnnConvolutionBwdDataAlgoPerf_t> algo_perf(max_count);
int ret_count = 0;
cudnn_check(cudnn.GetConvolutionBackwardDataAlgorithm_v7(
cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm_v7(
cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc,
desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count,
algo_perf.data()));
@@ -280,13 +279,12 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
#endif
#if CUDNN_MAJOR >= 7
MEGDNN_MARK_USED_VAR(negative_attr);
auto& cudnn = args.handle->cudnn();
int max_count = 0;
cudnn_check(cudnn.GetConvolutionBackwardFilterAlgorithmMaxCount(
cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(
cudnn_handle, &max_count));
SmallVector<cudnnConvolutionBwdFilterAlgoPerf_t> algo_perf(max_count);
int ret_count = 0;
cudnn_check(cudnn.GetConvolutionBackwardFilterAlgorithm_v7(
cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc,
desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count,
algo_perf.data()));
@@ -28,8 +28,7 @@ bool Convolution3DBackwardDataImpl::AlgoCUDNN::is_available(
args.init_desc(D);
size_t workspace_size;
auto& cudnn = args.handle->cudnn();
auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
args.handle->cudnn_handle(),
D.filter_desc.desc,
D.diff_desc.desc,
@@ -45,8 +44,7 @@ size_t Convolution3DBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes(
CUDNNBwdDataDescs D;
args.init_desc(D);
size_t workspace_size;
auto& cudnn = args.handle->cudnn();
auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
args.handle->cudnn_handle(),
D.filter_desc.desc,
D.diff_desc.desc,
@@ -28,8 +28,7 @@ bool Convolution3DBackwardFilterImpl::AlgoCUDNN::is_available(
args.init_desc(D);
size_t workspace_size;
auto& cudnn = args.handle->cudnn();
auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc,
D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size);
return status == CUDNN_STATUS_SUCCESS;
@@ -41,8 +40,7 @@ size_t Convolution3DBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes(
args.init_desc(D);
size_t workspace_size;
auto& cudnn = args.handle->cudnn();
auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc,
D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size);
megdnn_assert(status == CUDNN_STATUS_SUCCESS,
@@ -27,8 +27,7 @@ bool Convolution3DForwardImpl::AlgoCUDNN::is_available(
args.init_desc(D);
size_t workspace_size;
auto& cudnn = args.handle->cudnn();
auto status = cudnn.GetConvolutionForwardWorkspaceSize(
auto status = cudnnGetConvolutionForwardWorkspaceSize(
args.handle->cudnn_handle(),
D.src_desc.desc,
D.filter_desc.desc,
@@ -44,8 +43,7 @@ size_t Convolution3DForwardImpl::AlgoCUDNN::get_workspace_in_bytes(
CUDNNForwardDescs D;
args.init_desc(D);
size_t workspace_size;
auto& cudnn = args.handle->cudnn();
auto status = cudnn.GetConvolutionForwardWorkspaceSize(
auto status = cudnnGetConvolutionForwardWorkspaceSize(
args.handle->cudnn_handle(),
D.src_desc.desc,
D.filter_desc.desc,
@@ -93,7 +93,7 @@ namespace convolution3d {
const Workspace &workspace, void *&raw_ptr);
inline bool cudnn_get_convolution_fwd_algo_helper(
Handle* handle, const cudnnTensorDescriptor_t x_desc,
cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc,
const cudnnFilterDescriptor_t w_desc,
const cudnnConvolutionDescriptor_t conv_desc,
const cudnnTensorDescriptor_t y_desc,
@@ -103,14 +103,13 @@ namespace convolution3d {
MEGDNN_MARK_USED_VAR(positive_attr);
MEGDNN_MARK_USED_VAR(negative_attr);
#if CUDNN_MAJOR >= 7
auto& cudnn = static_cast<HandleImpl*>(handle)->cudnn();
int algo_max_count = 0;
cudnn_check(cudnn.GetConvolutionForwardAlgorithmMaxCount(
cuda::cudnn_handle(handle), &algo_max_count));
cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount(
cudnn_handle, &algo_max_count));
SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(algo_max_count);
int algo_count = 0;
cudnn_check(cudnn.GetConvolutionForwardAlgorithm_v7(
cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc, algo_max_count,
cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7(
cudnn_handle, x_desc, w_desc, conv_desc, y_desc, algo_max_count,
&algo_count, algo_perf.data()));
for (int i = 0; i < algo_count; ++i) {
if (algo_perf[i].algo ==
@@ -118,8 +117,8 @@ namespace convolution3d {
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING)
continue;
size_t workspace_size = 0;
cudnn_check(cudnn.GetConvolutionForwardWorkspaceSize(
cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc,
cudnn_check(cudnnGetConvolutionForwardWorkspaceSize(
cudnn_handle, x_desc, w_desc, conv_desc, y_desc,
algo_perf[i].algo, &workspace_size));
if (workspace_size > workspace_limit_in_bytes) continue;
if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) {
@@ -135,7 +134,7 @@ namespace convolution3d {
return false;
#else
cudnn_check(cudnnGetConvolutionForwardAlgorithm(
cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc,
cudnn_handle, x_desc, w_desc, conv_desc, y_desc,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_limit_in_bytes, algo));
return true;
@@ -64,12 +64,13 @@ Convolution3DForwardImpl::get_algorithm_heuristic(
auto get_cudnn_algo =
[this, &args, workspace_limit_in_bytes, positive_attr,
negative_attr]() -> Convolution3DForwardImpl::AlgoBase* {
auto cudnn_handle = cuda::cudnn_handle(this->handle());
cudnnConvolutionFwdAlgo_t algo;
CUDNNForwardDescs desc;
args.init_desc(desc);
bool got = cudnn_get_convolution_fwd_algo_helper(
this->handle(), desc.src_desc.desc, desc.filter_desc.desc,
cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc,
desc.conv_desc.desc, desc.dst_desc.desc,
workspace_limit_in_bytes, &algo, positive_attr, negative_attr);
if (got) {
@@ -56,7 +56,7 @@ namespace convolution {
using KernLayout = _kern_layout; \
using OutputLayout = _output_layout; \
using Param = _conv_param; \
static constexpr bool check_bounds = check_bounds_
static constexpr bool check_bounds = check_bounds_;
#define MEGDNN_COMMA ,
template <bool check_bounds_, typename src_ldg_dtype, typename filter_ldg_dtype,
@@ -53,7 +53,7 @@ namespace convolution {
using KernLayout = _kern_layout; \
using OutputLayout = _output_layout; \
using Param = _conv_param; \
static constexpr bool check_bounds = check_bounds_
static constexpr bool check_bounds = check_bounds_;
#define MEGDNN_COMMA ,
template <bool check_bounds_, typename IMMAConfig_, typename WarpTileConfig_,
@@ -53,7 +53,7 @@ namespace convolution {
using KernLayout = _kern_layout; \
using OutputLayout = _output_layout; \
using Param = _conv_param; \
static constexpr bool check_bounds = check_bounds_
static constexpr bool check_bounds = check_bounds_;
#define MEGDNN_COMMA ,
template <bool check_bounds_, typename ldg_dtype, typename RegBlockConfig_,
@@ -11,16 +11,13 @@
#include "src/common/handle_impl.h"
#include "src/common/version_symbol.h"
#include "src/common/api_cache.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
#include "src/cuda/api_cache.h"
#include "megdnn/common.h"
#include <cuda.h>
#include <cstring>
#include <memory>
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)
@@ -94,8 +91,6 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle):
// check tk1
m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0);
m_cusolver_handle = nullptr;
m_cudnn_api_cache = std::make_unique<CUDNN>(m_cudnn_handle);
}
HandleImpl::~HandleImpl() noexcept {
@@ -141,111 +136,8 @@ HandleImpl::HandleVendorType HandleImpl::vendor_type() const {
return HandleVendorType::CUDA;
}
HandleImpl::CUDNN& HandleImpl::cudnn() {
return *m_cudnn_api_cache;
}
HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
m_handle = handle;
GetConvolutionForwardWorkspaceSize =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.input<CudnnTensorDescParam>()
.input<CudnnFilterDescParam>()
.input<CudnnConvDescParam>()
.input<CudnnTensorDescParam>()
.input<Param<cudnnConvolutionFwdAlgo_t>>()
.output<RefParam<size_t>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionForwardWorkspaceSize);
#if CUDNN_MAJOR >= 7
GetConvolutionForwardAlgorithm_v7 =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.input<CudnnTensorDescParam>()
.input<CudnnFilterDescParam>()
.input<CudnnConvDescParam>()
.input<CudnnTensorDescParam>()
.input<Param<int>>()
.output<RefArraySizeParam<int>>()
.output<ArrayParam<int, cudnnConvolutionFwdAlgoPerf_t>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionForwardAlgorithm_v7);
GetConvolutionForwardAlgorithmMaxCount =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.output<RefParam<int>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionForwardAlgorithmMaxCount);
#endif
GetConvolutionBackwardDataWorkspaceSize =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.input<CudnnFilterDescParam>()
.input<CudnnTensorDescParam>()
.input<CudnnConvDescParam>()
.input<CudnnTensorDescParam>()
.input<Param<cudnnConvolutionBwdDataAlgo_t>>()
.output<RefParam<size_t>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionBackwardDataWorkspaceSize);
#if CUDNN_MAJOR >= 7
GetConvolutionBackwardDataAlgorithm_v7 =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.input<CudnnFilterDescParam>()
.input<CudnnTensorDescParam>()
.input<CudnnConvDescParam>()
.input<CudnnTensorDescParam>()
.input<Param<int>>()
.output<RefArraySizeParam<int>>()
.output<ArrayParam<int,
cudnnConvolutionBwdDataAlgoPerf_t>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionBackwardDataAlgorithm_v7);
GetConvolutionBackwardDataAlgorithmMaxCount =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.output<RefParam<int>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionBackwardDataAlgorithmMaxCount);
#endif
GetConvolutionBackwardFilterWorkspaceSize =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.input<CudnnTensorDescParam>()
.input<CudnnTensorDescParam>()
.input<CudnnConvDescParam>()
.input<CudnnFilterDescParam>()
.input<Param<cudnnConvolutionBwdFilterAlgo_t>>()
.output<RefParam<size_t>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionBackwardFilterWorkspaceSize);
#if CUDNN_MAJOR >= 7
GetConvolutionBackwardFilterAlgorithm_v7 =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.input<CudnnTensorDescParam>()
.input<CudnnTensorDescParam>()
.input<CudnnConvDescParam>()
.input<CudnnFilterDescParam>()
.input<Param<int>>()
.output<RefArraySizeParam<int>>()
.output<ArrayParam<int,
cudnnConvolutionBwdFilterAlgoPerf_t>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionBackwardFilterAlgorithm_v7);
GetConvolutionBackwardFilterAlgorithmMaxCount =
FunctionCacheBuilder<>()
.input<Param<cudnnHandle_t>>()
.output<RefParam<int>>()
.ret<Param<cudnnStatus_t>>()
.build(&cudnnGetConvolutionBackwardFilterAlgorithmMaxCount);
#endif
}
} // namespace cuda
} // namespace megdnn
} // namespace cuda
} // namespace megdnn
MEGDNN_VERSION_SYMBOL(CUDA, CUDA_VERSION);
MEGDNN_VERSION_SYMBOL3(CUDNN, CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
@@ -124,10 +124,6 @@ class HandleImpl: public HandleImplHelper {
size_t image2d_pitch_alignment() const override;
HandleVendorType vendor_type() const override;
class CUDNN;
CUDNN& cudnn();
private:
bool m_is_tegra_k1;
int m_device_id;
@@ -160,34 +156,9 @@ class HandleImpl: public HandleImplHelper {
//! device ptr to const scalars
ConstScalars* m_const_scalars;
std::unique_ptr<CUDNN> m_cudnn_api_cache;
void initialize_cusolver();
};
class HandleImpl::CUDNN {
cudnnHandle_t m_handle;
public:
CUDNN(cudnnHandle_t handle);
#define WRAP_CUDNN_API(NAME) thin_function<decltype(cudnn##NAME)> NAME;
WRAP_CUDNN_API(GetConvolutionForwardWorkspaceSize);
#if CUDNN_MAJOR >= 7
WRAP_CUDNN_API(GetConvolutionForwardAlgorithm_v7);
WRAP_CUDNN_API(GetConvolutionForwardAlgorithmMaxCount);
#endif
#if CUDNN_MAJOR >= 7
WRAP_CUDNN_API(GetConvolutionBackwardDataAlgorithm_v7);
WRAP_CUDNN_API(GetConvolutionBackwardDataAlgorithmMaxCount);
#endif
WRAP_CUDNN_API(GetConvolutionBackwardDataWorkspaceSize);
#if CUDNN_MAJOR >= 7
WRAP_CUDNN_API(GetConvolutionBackwardFilterAlgorithmMaxCount);
WRAP_CUDNN_API(GetConvolutionBackwardFilterAlgorithm_v7);
#endif
WRAP_CUDNN_API(GetConvolutionBackwardFilterWorkspaceSize);
#undef WRAP_CUDNN_API
};
} // namespace cuda
} // namespace megdnn