This reverts commit 188c62cdd6.

GitOrigin-RevId: 92a82b8cd9
@@ -12,28 +12,32 @@
 #pragma once
-#include <cstring>
-#include <unordered_map>
 #include <memory>
+#include <cstring>
 #include <tuple>
+#include <unordered_map>
 #include "megdnn/thin/function.h"
 namespace megdnn {
-template <typename... TArgs>
-class FunctionCache {
+template <typename TSignature>
+class FunctionCache;
+template <typename TRet, typename... TArgs>
+class FunctionCache<TRet(TArgs...)> {
 public:
     using key_t = std::string;
-    using value_t = std::string;
+    using value_t = TRet;
     using key_mapper_t = thin_function<key_t(TArgs...)>;
     using value_mapper_t = thin_function<value_t(TArgs...)>;
     using storage_t = std::unordered_map<key_t, value_t>;
+public:
     storage_t storage;
     key_mapper_t key_mapper;
     value_mapper_t value_mapper;
-    value_t operator()(TArgs... args) {
+public:
+    TRet operator()(TArgs... args) {
         key_t key = key_mapper(args...);
         if (storage.count(key) == 0) {
             storage[key] = value_mapper(std::forward<TArgs>(args)...);
@@ -42,28 +46,28 @@ public:
     }
 };
 // FIFO
 class StringSerializer {
 private:
     std::string m_buffer;
     size_t m_cursor = 0;
 public:
     template <typename T>
     T read_plain() {
-        static_assert(std::is_trivially_copyable<T>::value, "invalid type");
-        T ret;
-        memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
+        T result;
+        std::memcpy(&result, m_buffer.data() + m_cursor, sizeof(T));
         m_cursor += sizeof(T);
-        return ret;
+        return result;
     }
     template <typename T>
     void write_plain(T value) {
-        static_assert(std::is_trivially_copyable<T>::value,
-                      "type should be trivially copyable");
-        m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T));
+        m_buffer.resize(m_buffer.size() + sizeof(T));
+        std::memcpy(const_cast<char*>(m_buffer.data()) + (m_buffer.size() - sizeof(T)), &value, sizeof(T));
     }
     std::string take() {
+        std::string result;
+        m_buffer.erase(0, m_cursor);
         return std::move(m_buffer);
     }
     void set(std::string new_buf) {
@@ -72,20 +76,20 @@ public:
     }
 };
 struct Empty {};
 template <typename... TParams>
 class ParamBundle {
 private:
-    template <std::size_t N, std::size_t... Seq>
-    static std::index_sequence<N + Seq...> add_all(
-            std::index_sequence<Seq...>) {
+    template<std::size_t N, std::size_t... Seq>
+    static std::index_sequence<N + Seq ...> add_all(std::index_sequence<Seq...>){
         return {};
     }
-    template <std::size_t Min, std::size_t Max>
-    using make_index_range =
-            decltype(add_all<Min>(std::make_index_sequence<Max - Min>()));
+    template<std::size_t Min, std::size_t Max>
+    using make_index_range = decltype(add_all<Min>(std::make_index_sequence<Max-Min>()));
     using storage_t = std::tuple<typename std::remove_reference_t<TParams>...>;
     storage_t m_storage;
@@ -95,31 +99,21 @@ private:
         return functor(std::get<Indices>(m_storage).value...);
     }
     template <size_t Index, size_t... Indices, typename TPrev>
-    auto serialize_helper(StringSerializer& ser, TPrev&& prev,
-                          std::index_sequence<Index, Indices...>) {
-        return serialize_helper(ser,
-                                std::get<Index>(m_storage).serialize(ser, prev),
-                                std::index_sequence<Indices...>());
+    auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<Index, Indices...>) {
+        return serialize_helper(ser, std::get<Index>(m_storage).serialize(ser, prev), std::index_sequence<Indices...>());
     }
     template <typename TPrev>
-    auto serialize_helper(StringSerializer& ser, TPrev&& prev,
-                          std::index_sequence<>) {}
+    auto serialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {}
     template <size_t Index, size_t... Indices, typename TPrev>
-    auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
-                            std::index_sequence<Index, Indices...>) {
-        return deserialize_helper(
-                ser, std::get<Index>(m_storage).deserialize(ser, prev),
-                std::index_sequence<Indices...>());
+    auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<Index, Indices...>) {
+        return deserialize_helper(ser, std::get<Index>(m_storage).deserialize(ser, prev), std::index_sequence<Indices...>());
     }
     template <typename TPrev>
-    auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
-                            std::index_sequence<>) {}
+    auto deserialize_helper(StringSerializer& ser, TPrev&& prev, std::index_sequence<>) {}
     template <size_t Index, size_t... Indices, typename TArg, typename... TArgs>
-    void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg,
-                           TArgs&&... args) {
+    void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg, TArgs&&... args) {
         std::get<Index>(m_storage).value = arg;
-        set_values_helper(std::index_sequence<Indices...>(),
-                          std::forward<TArgs>(args)...);
+        set_values_helper(std::index_sequence<Indices...>(), std::forward<TArgs>(args)...);
     }
     template <size_t... Indices>
     void set_values_helper(std::index_sequence<Indices...>) {
@@ -129,33 +123,27 @@ private:
 public:
     template <typename TFunctor>
     auto call_by(TFunctor&& functor) {
-        return call_helper(std::forward<TFunctor>(functor),
-                           std::make_index_sequence<sizeof...(TParams)>());
+        return call_helper(std::forward<TFunctor>(functor), std::make_index_sequence<sizeof...(TParams)>());
     }
     template <size_t NBegin, size_t NEnd>
     void serialize_params(StringSerializer& ser) {
         static_assert(NEnd >= NBegin, "invalid range");
-        serialize_helper(
-                ser, Empty{},
-                add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()));
+        serialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
     }
     template <size_t NBegin, size_t NEnd>
     void deserialize_params(StringSerializer& ser) {
         static_assert(NEnd >= NBegin, "invalid range");
-        deserialize_helper(
-                ser, Empty{},
-                add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()));
+        deserialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
     }
     template <size_t NBegin, size_t NEnd, typename... TArgs>
     void set_values(TArgs&&... args) {
-        set_values_helper(
-                add_all<NBegin>(std::make_index_sequence<NEnd - NBegin>()),
-                std::forward<TArgs>(args)...);
+        set_values_helper(make_index_range<NBegin, NEnd>(), std::forward<TArgs>(args)...);
     }
 };
 template <typename T>
-class Param {
+class RetParam {
 public:
     T value;
     Empty serialize(StringSerializer& ser, Empty) {
@@ -168,68 +156,45 @@ public:
     }
 };
-template <typename TRet = Param<Empty>, typename TInputs = std::tuple<>,
-          typename TOutputs = std::tuple<>>
+template <typename TRet=RetParam<Empty>, typename TInputs=std::tuple<>, typename TOutputs=std::tuple<>>
 class FunctionCacheBuilder {
 private:
-    static auto declargs()
-            -> decltype(std::tuple_cat(std::declval<TInputs>(),
-                                       std::declval<TOutputs>())) {
-        return {};
-    }
+    static auto declargs() -> decltype(std::tuple_cat(std::declval<TInputs>(), std::declval<TOutputs>())) { return {}; }
     template <size_t... Indices>
-    static auto declfunction_helper(std::index_sequence<Indices...>)
-            -> thin_function<decltype(std::declval<TRet>().value)(
-                    decltype(std::get<Indices>(declargs()).value)...)> {
-        return {};
-    }
+    static auto declfunction_helper(std::index_sequence<Indices...>) -> thin_function<decltype(std::declval<TRet>().value)(decltype(std::get<Indices>(declargs()).value)...)> { return {}; }
     static auto declfunction() {
-        return declfunction_helper(
-                std::make_index_sequence<std::tuple_size<TInputs>::value +
-                                         std::tuple_size<TOutputs>::value>());
+        return declfunction_helper(std::make_index_sequence<std::tuple_size<TInputs>::value + std::tuple_size<TOutputs>::value>());
     }
     template <size_t... Indices>
-    static auto declbundle_helper(std::index_sequence<Indices...>)
-            -> ParamBundle<decltype(std::get<Indices>(declargs()))...> {
-        return {};
-    }
+    static auto declbundle_helper(std::index_sequence<Indices...>) -> ParamBundle<decltype(std::get<Indices>(declargs()))...> { return {}; }
     static auto declbundle() {
-        return declbundle_helper(
-                std::make_index_sequence<std::tuple_size<TInputs>::value +
-                                         std::tuple_size<TOutputs>::value>());
+        return declbundle_helper(std::make_index_sequence<std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value>());
     }
     using function_t = decltype(declfunction());
     using bundle_t = decltype(declbundle());
 public:
     template <typename TNewRet>
     auto ret() {
-        static_assert(std::is_same<TRet, Param<Empty>>::value,
-                      "return value redefinition");
+        static_assert(std::is_same<TRet, RetParam<Empty>>::value, "return value redefinition");
         return FunctionCacheBuilder<TNewRet, TInputs, TOutputs>{};
     }
     template <typename TNewInput>
     auto input() {
-        using TNewInputs = decltype(
-                std::tuple_cat(std::declval<TInputs>(),
-                               std::make_tuple(std::declval<TNewInput>())));
+        using TNewInputs = decltype(std::tuple_cat(std::declval<TInputs>(), std::make_tuple(std::declval<TNewInput>())));
         return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{};
     }
     template <typename TNewOutput>
     auto output() {
-        using TNewOutputs = decltype(
-                std::tuple_cat(std::declval<TOutputs>(),
-                               std::make_tuple(std::declval<TNewOutput>())));
+        using TNewOutputs = decltype(std::tuple_cat(std::declval<TOutputs>(), std::make_tuple(std::declval<TNewOutput>())));
        return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{};
     }
     template <typename TFunctor>
     function_t build(TFunctor func) {
-        FunctionCache<bundle_t> cache;
+        FunctionCache<std::string(bundle_t)> cache;
         cache.key_mapper = [](bundle_t bundle) {
             StringSerializer ser;
-            bundle.template serialize_params<0,
-                                             std::tuple_size<TInputs>::value>(
-                    ser);
+            bundle.template serialize_params<0, std::tuple_size<TInputs>::value>(ser);
             return ser.take();
         };
         cache.value_mapper = [=](bundle_t bundle) {
@@ -237,33 +202,42 @@ public:
             TRet ret;
             ret.value = bundle.call_by(func);
             ret.serialize(ser, Empty{});
-            bundle.template serialize_params<
-                    std::tuple_size<TInputs>::value,
-                    std::tuple_size<TInputs>::value +
-                            std::tuple_size<TOutputs>::value>(ser);
+            bundle.template serialize_params<std::tuple_size<TInputs>::value, std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value>(ser);
             return ser.take();
         };
         return [=](auto&&... args) mutable {
             bundle_t bundle;
             TRet ret;
             StringSerializer ser;
-            static_assert(
-                    sizeof...(args) == std::tuple_size<TInputs>::value +
-                                               std::tuple_size<TOutputs>::value,
-                    "args count mismatch");
-            bundle.template set_values<0, sizeof...(args)>(
-                    std::forward<decltype(args)>(args)...);
+            static_assert(sizeof...(args) == std::tuple_size<TInputs>::value+std::tuple_size<TOutputs>::value,
+                "arg count mismatch");
+            bundle.template set_values<0, sizeof...(args)>(std::forward<decltype(args)>(args)...);
             ser.set(cache(bundle));
             ret.deserialize(ser, Empty{});
             constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
             constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
-            bundle.template deserialize_params<n_inputs, n_inputs + n_outputs>(
-                    ser);
+            bundle.template deserialize_params<n_inputs, n_inputs+n_outputs>(ser);
             return ret.value;
         };
     }
 };
+template <typename T>
+class PlainParam {
+public:
+    T value;
+    Empty serialize(StringSerializer& ser, Empty) {
+        ser.write_plain(value);
+        return Empty{};
+    }
+    Empty deserialize(StringSerializer& ser, Empty) {
+        value = ser.read_plain<T>();
+        return Empty{};
+    }
+};
 template <typename T>
 class RefParam {
 public:
@@ -278,6 +252,7 @@ public:
     }
 };
 template <typename T>
 class RefArraySizeParam {
 public:
@@ -291,6 +266,7 @@ public:
     }
 };
 template <typename TSize, typename TItem>
 class ArrayParam {
 public:
@@ -309,4 +285,4 @@ public:
     }
 };
-} // namespace megdnn
+}
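Note on the API above: the cache key is whatever the input params serialize to, and the cached value is the serialized return-plus-output block. A minimal, hypothetical usage sketch (not part of this patch) of the builder in its restored form, assuming the names FunctionCacheBuilder, PlainParam and thin_function from this file are in scope:

    // cache a pure function of one plain (trivially copyable) input
    auto cached_square = megdnn::FunctionCacheBuilder<>()
                                 .input<megdnn::PlainParam<int>>()
                                 .ret<megdnn::PlainParam<int>>()
                                 .build([](int x) { return x * x; });
    int y = cached_square(3);  // first call runs the lambda and stores the
                               // serialized result; identical inputs later
                               // are answered from the unordered_map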
@@ -16,109 +16,105 @@
 #include "src/cuda/cudnn_wrapper.h"
 namespace megdnn {
-class CudnnConvDescParam {
-public:
-    cudnnConvolutionDescriptor_t value;
-    Empty serialize(StringSerializer& ser, Empty) {
-        constexpr int nbDims = MEGDNN_MAX_NDIM;
-        int padA[MEGDNN_MAX_NDIM];
-        int strideA[MEGDNN_MAX_NDIM];
-        int dilationA[MEGDNN_MAX_NDIM];
-        cudnnConvolutionMode_t mode;
-        cudnnDataType_t computeType;
-        cudnnGetConvolutionNdDescriptor(value, nbDims, &nbDims, padA, strideA,
-                                        dilationA, &mode, &computeType);
-        ser.write_plain(nbDims);
-        for (int i = 0; i < nbDims; ++i) {
-            ser.write_plain(padA[i]);
-            ser.write_plain(strideA[i]);
-            ser.write_plain(dilationA[i]);
+class CudnnConvDescParam {
+public:
+    cudnnConvolutionDescriptor_t value;
+    Empty serialize(StringSerializer& ser, Empty) {
+        int ndim = MEGDNN_MAX_NDIM;
+        int padA[MEGDNN_MAX_NDIM];
+        int strideA[MEGDNN_MAX_NDIM];
+        int dilationA[MEGDNN_MAX_NDIM];
+        cudnnConvolutionMode_t mode;
+        cudnnDataType_t computeType;
+        cudnnGetConvolutionNdDescriptor(value, MEGDNN_MAX_NDIM, &ndim, padA, strideA, dilationA, &mode, &computeType);
+        ser.write_plain(ndim);
+        for (int i = 0; i < ndim; ++i) {
+            ser.write_plain(padA[i]);
+            ser.write_plain(strideA[i]);
+            ser.write_plain(dilationA[i]);
+        }
+        ser.write_plain(mode);
+        ser.write_plain(computeType);
+        return Empty{};
         }
-        ser.write_plain(mode);
-        ser.write_plain(computeType);
-        return Empty{};
-    }
-    Empty deserialize(StringSerializer& ser, Empty) {
-        int ndim = ser.read_plain<int>();
-        int padA[MEGDNN_MAX_NDIM];
-        int strideA[MEGDNN_MAX_NDIM];
-        int dilationA[MEGDNN_MAX_NDIM];
-        for (int i = 0; i < ndim; ++i) {
-            padA[i] = ser.read_plain<int>();
-            strideA[i] = ser.read_plain<int>();
-            dilationA[i] = ser.read_plain<int>();
+    Empty deserialize(StringSerializer& ser, Empty) {
+        int ndim = ser.read_plain<int>();
+        int padA[MEGDNN_MAX_NDIM];
+        int strideA[MEGDNN_MAX_NDIM];
+        int dilationA[MEGDNN_MAX_NDIM];
+        for (int i = 0; i < ndim; ++i) {
+            padA[i] = ser.read_plain<int>();
+            strideA[i] = ser.read_plain<int>();
+            dilationA[i] = ser.read_plain<int>();
+        }
+        cudnnConvolutionMode_t mode = ser.read_plain<cudnnConvolutionMode_t>();
+        cudnnDataType_t computeType = ser.read_plain<cudnnDataType_t>();
+        cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA, mode, computeType);
+        return Empty{};
         }
-        cudnnConvolutionMode_t mode = ser.read_plain<cudnnConvolutionMode_t>();
-        cudnnDataType_t computeType = ser.read_plain<cudnnDataType_t>();
-        cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA,
-                                        mode, computeType);
-        return Empty{};
-    }
-};
-class CudnnTensorDescParam {
-public:
-    cudnnTensorDescriptor_t value;
-    Empty serialize(StringSerializer& ser, Empty) {
-        constexpr int nbDims = MEGDNN_MAX_NDIM;
-        cudnnDataType_t dataType;
-        int dimA[MEGDNN_MAX_NDIM];
-        int strideA[MEGDNN_MAX_NDIM];
-        cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA,
-                                   strideA);
-        ser.write_plain(nbDims);
-        for (int i = 0; i < nbDims; ++i) {
-            ser.write_plain(dimA[i]);
-            ser.write_plain(strideA[i]);
+};
+class CudnnTensorDescParam {
+public:
+    cudnnTensorDescriptor_t value;
+    Empty serialize(StringSerializer& ser, Empty) {
+        int nbDims = MEGDNN_MAX_NDIM;
+        cudnnDataType_t dataType;
+        int dimA[MEGDNN_MAX_NDIM];
+        int strideA[MEGDNN_MAX_NDIM];
+        cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA, strideA);
+        ser.write_plain(nbDims);
+        for (int i = 0; i < nbDims; ++i) {
+            ser.write_plain(dimA[i]);
+            ser.write_plain(strideA[i]);
+        }
+        ser.write_plain(dataType);
+        return Empty{};
         }
-        ser.write_plain(dataType);
-        return Empty{};
-    }
-    Empty deserialize(StringSerializer& ser, Empty) {
-        constexpr int nbDims = MEGDNN_MAX_NDIM;
-        cudnnDataType_t dataType;
-        int dimA[MEGDNN_MAX_NDIM];
-        int strideA[MEGDNN_MAX_NDIM];
-        nbDims = ser.read_plain<int>();
-        for (int i = 0; i < nbDims; ++i) {
-            dimA[i] = ser.read_plain<int>();
-            strideA[i] = ser.read_plain<int>();
+    Empty deserialize(StringSerializer& ser, Empty) {
+        int nbDims = MEGDNN_MAX_NDIM;
+        cudnnDataType_t dataType;
+        int dimA[MEGDNN_MAX_NDIM];
+        int strideA[MEGDNN_MAX_NDIM];
+        nbDims = ser.read_plain<int>();
+        for (int i = 0; i < nbDims; ++i) {
+            dimA[i] = ser.read_plain<int>();
+            strideA[i] = ser.read_plain<int>();
+        }
+        dataType = ser.read_plain<cudnnDataType_t>();
+        cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA);
+        return Empty{};
         }
-        dataType = ser.read_plain<cudnnDataType_t>();
-        cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA);
-        return Empty{};
-    }
-};
-class CudnnFilterDescParam {
-public:
-    cudnnFilterDescriptor_t value;
-    Empty serialize(StringSerializer& ser, Empty) {
-        constexpr int nbDims = MEGDNN_MAX_NDIM;
-        cudnnDataType_t dataType;
-        cudnnTensorFormat_t format;
-        int filterDimA[MEGDNN_MAX_NDIM];
-        cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims,
-                                   filterDimA);
-        ser.write_plain(nbDims);
-        for (int i = 0; i < nbDims; ++i) {
-            ser.write_plain(filterDimA[i]);
+};
+class CudnnFilterDescParam {
+public:
+    cudnnFilterDescriptor_t value;
+    Empty serialize(StringSerializer& ser, Empty) {
+        int nbDims = MEGDNN_MAX_NDIM;
+        cudnnDataType_t dataType;
+        cudnnTensorFormat_t format;
+        int filterDimA[MEGDNN_MAX_NDIM];
+        cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims, filterDimA);
+        ser.write_plain(nbDims);
+        for (int i = 0; i < nbDims; ++i) {
+            ser.write_plain(filterDimA[i]);
+        }
+        ser.write_plain(dataType);
+        ser.write_plain(format);
+        return Empty{};
         }
-        ser.write_plain(dataType);
-        ser.write_plain(format);
-        return Empty{};
-    }
-    Empty deserialize(StringSerializer& ser, Empty) {
-        constexpr int nbDims = MEGDNN_MAX_NDIM;
-        cudnnDataType_t dataType;
-        cudnnTensorFormat_t format;
-        int filterDimA[MEGDNN_MAX_NDIM];
-        nbDims = ser.read_plain<int>();
-        for (int i = 0; i < nbDims; ++i) {
-            filterDimA[i] = ser.read_plain<int>();
+    Empty deserialize(StringSerializer& ser, Empty) {
+        int nbDims = MEGDNN_MAX_NDIM;
+        cudnnDataType_t dataType;
+        cudnnTensorFormat_t format;
+        int filterDimA[MEGDNN_MAX_NDIM];
+        nbDims = ser.read_plain<int>();
+        for (int i = 0; i < nbDims; ++i) {
+            filterDimA[i] = ser.read_plain<int>();
+        }
+        dataType = ser.read_plain<cudnnDataType_t>();
+        format = ser.read_plain<cudnnTensorFormat_t>();
+        cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA);
+        return Empty{};
         }
-        dataType = ser.read_plain<cudnnDataType_t>();
-        format = ser.read_plain<cudnnTensorFormat_t>();
-        cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA);
-        return Empty{};
-    }
-};
-} // namespace megdnn
+};
+}
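Both versions of the descriptor params above reduce a cuDNN descriptor to plain int/enum fields pushed through StringSerializer and rebuild it on the way out. A small, hypothetical round-trip sketch of that underlying mechanism (values invented for illustration; it uses only the StringSerializer interface shown in the first file):

    megdnn::StringSerializer writer;
    writer.write_plain<int>(2);        // e.g. an nbDims field
    writer.write_plain<float>(3.5f);
    megdnn::StringSerializer reader;
    reader.set(writer.take());         // FIFO: bytes come back in write order
    int ndim = reader.read_plain<int>();      // 2
    float scale = reader.read_plain<float>(); // 3.5f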
@@ -56,8 +56,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(
     conv_args.init_conv_desc(D);
     size_t workspace_size;
-    auto& cudnn = conv_args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             conv_args.handle->cudnn_handle(), D.src_desc.desc,
             D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc,
             m_cudnn_enum, &workspace_size);
@@ -83,8 +82,7 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_bundle(
     conv_args.init_conv_desc(D);
     size_t conv_workspace_size;
-    auto& cudnn = conv_args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             conv_args.handle->cudnn_handle(), D.src_desc.desc,
             D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc,
             m_cudnn_enum, &conv_workspace_size);
@@ -149,8 +149,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
         megdnn_throw("unsupported NonlineMode");
     }
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
             D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum,
             &workspace_size);
@@ -163,8 +162,7 @@ size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::get_workspace_in_bytes(
     args.init_conv_bias_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc,
             D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum,
             &workspace_size);
@@ -95,13 +95,12 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
     CUDNNForwardDescs desc;
     conv_args.init_conv_desc(desc);
 #if CUDNN_MAJOR >= 7
-    auto& cudnn = static_cast<HandleImpl*>(this->handle())->cudnn();
     int max_count = 0;
-    cudnn_check(cudnn.GetConvolutionForwardAlgorithmMaxCount(cudnn_handle,
+    cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle,
                                                              &max_count));
     SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(max_count);
     int ret_count = 0;
-    cudnn_check(cudnn.GetConvolutionForwardAlgorithm_v7(
+    cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7(
             cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc,
             desc.conv_desc.conv_desc, desc.dst_desc.desc, max_count,
             &ret_count, algo_perf.data()));
@@ -42,10 +42,9 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available(
     if (!conv_bias::is_cudnn_supported(bias_args))
         return false;
-    auto& cudnn = args.handle->cudnn();
     args.init_desc(D);
     size_t workspace_size;
-    auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
             args.handle->cudnn_handle(),
             D.filter_desc.desc,
             D.diff_desc.desc,
@@ -58,11 +57,10 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available(
 size_t ConvolutionBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes(
         const SizeArgs &args) const {
-    auto& cudnn = args.handle->cudnn();
     CUDNNBwdDataDescs D;
     args.init_desc(D);
     size_t workspace_size;
-    auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
             args.handle->cudnn_handle(),
             D.filter_desc.desc,
             D.diff_desc.desc,
@@ -29,7 +29,6 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
             return false;
         }
     }
-    auto& cudnn = args.handle->cudnn();
     CUDNNBwdFilterDescs D;
     TensorLayout bias_layout, z_layout;
@@ -44,7 +43,7 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
     args.init_desc(D);
     size_t workspace_size;
-    auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
             args.handle->cudnn_handle(),
             D.src_desc.desc,
             D.diff_desc.desc,
@@ -57,11 +56,10 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available(
 size_t ConvolutionBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes(
         const SizeArgs &args) const {
-    auto& cudnn = args.handle->cudnn();
     CUDNNBwdFilterDescs D;
     args.init_desc(D);
     size_t workspace_size;
-    auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
             args.handle->cudnn_handle(),
             D.src_desc.desc,
             D.diff_desc.desc,
@@ -144,13 +144,12 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(
 #if CUDNN_MAJOR >= 7
     MEGDNN_MARK_USED_VAR(negative_attr);
-    auto& cudnn = args.handle->cudnn();
     int max_count = 0;
-    cudnn_check(cudnn.GetConvolutionBackwardDataAlgorithmMaxCount(
+    cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount(
             cudnn_handle, &max_count));
     SmallVector<cudnnConvolutionBwdDataAlgoPerf_t> algo_perf(max_count);
     int ret_count = 0;
-    cudnn_check(cudnn.GetConvolutionBackwardDataAlgorithm_v7(
+    cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm_v7(
             cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc,
             desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count,
             algo_perf.data()));
@@ -280,13 +279,12 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic(
 #endif
 #if CUDNN_MAJOR >= 7
     MEGDNN_MARK_USED_VAR(negative_attr);
-    auto& cudnn = args.handle->cudnn();
     int max_count = 0;
-    cudnn_check(cudnn.GetConvolutionBackwardFilterAlgorithmMaxCount(
+    cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount(
             cudnn_handle, &max_count));
     SmallVector<cudnnConvolutionBwdFilterAlgoPerf_t> algo_perf(max_count);
     int ret_count = 0;
-    cudnn_check(cudnn.GetConvolutionBackwardFilterAlgorithm_v7(
+    cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
             cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc,
             desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count,
             algo_perf.data()));
@@ -28,8 +28,7 @@ bool Convolution3DBackwardDataImpl::AlgoCUDNN::is_available(
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
             args.handle->cudnn_handle(),
             D.filter_desc.desc,
             D.diff_desc.desc,
@@ -45,8 +44,7 @@ size_t Convolution3DBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes(
     CUDNNBwdDataDescs D;
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardDataWorkspaceSize(
             args.handle->cudnn_handle(),
             D.filter_desc.desc,
             D.diff_desc.desc,
@@ -28,8 +28,7 @@ bool Convolution3DBackwardFilterImpl::AlgoCUDNN::is_available(
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
             args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc,
             D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size);
     return status == CUDNN_STATUS_SUCCESS;
@@ -41,8 +40,7 @@ size_t Convolution3DBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes(
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize(
+    auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize(
             args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc,
             D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size);
     megdnn_assert(status == CUDNN_STATUS_SUCCESS,
@@ -27,8 +27,7 @@ bool Convolution3DForwardImpl::AlgoCUDNN::is_available(
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             args.handle->cudnn_handle(),
             D.src_desc.desc,
             D.filter_desc.desc,
@@ -44,8 +43,7 @@ size_t Convolution3DForwardImpl::AlgoCUDNN::get_workspace_in_bytes(
     CUDNNForwardDescs D;
     args.init_desc(D);
     size_t workspace_size;
-    auto& cudnn = args.handle->cudnn();
-    auto status = cudnn.GetConvolutionForwardWorkspaceSize(
+    auto status = cudnnGetConvolutionForwardWorkspaceSize(
             args.handle->cudnn_handle(),
             D.src_desc.desc,
             D.filter_desc.desc,
@@ -93,7 +93,7 @@ namespace convolution3d {
             const Workspace &workspace, void *&raw_ptr);
     inline bool cudnn_get_convolution_fwd_algo_helper(
-            Handle* handle, const cudnnTensorDescriptor_t x_desc,
+            cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc,
             const cudnnFilterDescriptor_t w_desc,
             const cudnnConvolutionDescriptor_t conv_desc,
             const cudnnTensorDescriptor_t y_desc,
@@ -103,14 +103,13 @@ namespace convolution3d {
         MEGDNN_MARK_USED_VAR(positive_attr);
         MEGDNN_MARK_USED_VAR(negative_attr);
 #if CUDNN_MAJOR >= 7
-        auto& cudnn = static_cast<HandleImpl*>(handle)->cudnn();
         int algo_max_count = 0;
-        cudnn_check(cudnn.GetConvolutionForwardAlgorithmMaxCount(
-                cuda::cudnn_handle(handle), &algo_max_count));
+        cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount(
+                cudnn_handle, &algo_max_count));
         SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(algo_max_count);
         int algo_count = 0;
-        cudnn_check(cudnn.GetConvolutionForwardAlgorithm_v7(
-                cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc, algo_max_count,
+        cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7(
+                cudnn_handle, x_desc, w_desc, conv_desc, y_desc, algo_max_count,
                 &algo_count, algo_perf.data()));
         for (int i = 0; i < algo_count; ++i) {
             if (algo_perf[i].algo ==
@@ -118,8 +117,8 @@ namespace convolution3d {
                 CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING)
                 continue;
             size_t workspace_size = 0;
-            cudnn_check(cudnn.GetConvolutionForwardWorkspaceSize(
-                    cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc,
+            cudnn_check(cudnnGetConvolutionForwardWorkspaceSize(
+                    cudnn_handle, x_desc, w_desc, conv_desc, y_desc,
                     algo_perf[i].algo, &workspace_size));
             if (workspace_size > workspace_limit_in_bytes) continue;
             if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) {
@@ -135,7 +134,7 @@ namespace convolution3d {
         return false;
 #else
         cudnn_check(cudnnGetConvolutionForwardAlgorithm(
-                cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc,
+                cudnn_handle, x_desc, w_desc, conv_desc, y_desc,
                 CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
                 workspace_limit_in_bytes, algo));
         return true;
@@ -64,12 +64,13 @@ Convolution3DForwardImpl::get_algorithm_heuristic(
     auto get_cudnn_algo =
             [this, &args, workspace_limit_in_bytes, positive_attr,
              negative_attr]() -> Convolution3DForwardImpl::AlgoBase* {
+        auto cudnn_handle = cuda::cudnn_handle(this->handle());
         cudnnConvolutionFwdAlgo_t algo;
         CUDNNForwardDescs desc;
         args.init_desc(desc);
         bool got = cudnn_get_convolution_fwd_algo_helper(
-                this->handle(), desc.src_desc.desc, desc.filter_desc.desc,
+                cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc,
                 desc.conv_desc.desc, desc.dst_desc.desc,
                 workspace_limit_in_bytes, &algo, positive_attr, negative_attr);
         if (got) {
@@ -56,7 +56,7 @@ namespace convolution {
         using KernLayout = _kern_layout; \
         using OutputLayout = _output_layout; \
         using Param = _conv_param; \
-        static constexpr bool check_bounds = check_bounds_
+        static constexpr bool check_bounds = check_bounds_;
 #define MEGDNN_COMMA ,
 template <bool check_bounds_, typename src_ldg_dtype, typename filter_ldg_dtype,
@@ -53,7 +53,7 @@ namespace convolution {
         using KernLayout = _kern_layout; \
         using OutputLayout = _output_layout; \
         using Param = _conv_param; \
-        static constexpr bool check_bounds = check_bounds_
+        static constexpr bool check_bounds = check_bounds_;
 #define MEGDNN_COMMA ,
 template <bool check_bounds_, typename IMMAConfig_, typename WarpTileConfig_,
@@ -53,7 +53,7 @@ namespace convolution {
         using KernLayout = _kern_layout; \
         using OutputLayout = _output_layout; \
         using Param = _conv_param; \
-        static constexpr bool check_bounds = check_bounds_
+        static constexpr bool check_bounds = check_bounds_;
 #define MEGDNN_COMMA ,
 template <bool check_bounds_, typename ldg_dtype, typename RegBlockConfig_,
@@ -11,16 +11,13 @@
 #include "src/common/handle_impl.h"
 #include "src/common/version_symbol.h"
-#include "src/common/api_cache.h"
 #include "src/cuda/handle.h"
 #include "src/cuda/utils.h"
+#include "src/cuda/api_cache.h"
 #include "megdnn/common.h"
 #include <cuda.h>
 #include <cstring>
-#include <memory>
 #define STR_HELPER(x) #x
 #define STR(x) STR_HELPER(x)
@@ -94,8 +91,6 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle):
     // check tk1
     m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0);
     m_cusolver_handle = nullptr;
-    m_cudnn_api_cache = std::make_unique<CUDNN>(m_cudnn_handle);
 }
 HandleImpl::~HandleImpl() noexcept {
@@ -141,111 +136,8 @@ HandleImpl::HandleVendorType HandleImpl::vendor_type() const {
     return HandleVendorType::CUDA;
 }
-HandleImpl::CUDNN& HandleImpl::cudnn() {
-    return *m_cudnn_api_cache;
-}
-HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) {
-    m_handle = handle;
-    GetConvolutionForwardWorkspaceSize =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnFilterDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<Param<cudnnConvolutionFwdAlgo_t>>()
-                    .output<RefParam<size_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionForwardWorkspaceSize);
-#if CUDNN_MAJOR >= 7
-    GetConvolutionForwardAlgorithm_v7 =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnFilterDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<Param<int>>()
-                    .output<RefArraySizeParam<int>>()
-                    .output<ArrayParam<int, cudnnConvolutionFwdAlgoPerf_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionForwardAlgorithm_v7);
-    GetConvolutionForwardAlgorithmMaxCount =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .output<RefParam<int>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionForwardAlgorithmMaxCount);
-#endif
-    GetConvolutionBackwardDataWorkspaceSize =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnFilterDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<Param<cudnnConvolutionBwdDataAlgo_t>>()
-                    .output<RefParam<size_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardDataWorkspaceSize);
-#if CUDNN_MAJOR >= 7
-    GetConvolutionBackwardDataAlgorithm_v7 =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnFilterDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<Param<int>>()
-                    .output<RefArraySizeParam<int>>()
-                    .output<ArrayParam<int,
-                                       cudnnConvolutionBwdDataAlgoPerf_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardDataAlgorithm_v7);
-    GetConvolutionBackwardDataAlgorithmMaxCount =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .output<RefParam<int>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardDataAlgorithmMaxCount);
-#endif
-    GetConvolutionBackwardFilterWorkspaceSize =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnFilterDescParam>()
-                    .input<Param<cudnnConvolutionBwdFilterAlgo_t>>()
-                    .output<RefParam<size_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardFilterWorkspaceSize);
-#if CUDNN_MAJOR >= 7
-    GetConvolutionBackwardFilterAlgorithm_v7 =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnTensorDescParam>()
-                    .input<CudnnConvDescParam>()
-                    .input<CudnnFilterDescParam>()
-                    .input<Param<int>>()
-                    .output<RefArraySizeParam<int>>()
-                    .output<ArrayParam<int,
-                                       cudnnConvolutionBwdFilterAlgoPerf_t>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardFilterAlgorithm_v7);
-    GetConvolutionBackwardFilterAlgorithmMaxCount =
-            FunctionCacheBuilder<>()
-                    .input<Param<cudnnHandle_t>>()
-                    .output<RefParam<int>>()
-                    .ret<Param<cudnnStatus_t>>()
-                    .build(&cudnnGetConvolutionBackwardFilterAlgorithmMaxCount);
-#endif
-}
-} // namespace cuda
-} // namespace megdnn
+} // namespace cuda
+} // namespace megdnn
 MEGDNN_VERSION_SYMBOL(CUDA, CUDA_VERSION);
 MEGDNN_VERSION_SYMBOL3(CUDNN, CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
@@ -124,10 +124,6 @@ class HandleImpl: public HandleImplHelper {
     size_t image2d_pitch_alignment() const override;
     HandleVendorType vendor_type() const override;
-    class CUDNN;
-    CUDNN& cudnn();
 private:
     bool m_is_tegra_k1;
     int m_device_id;
@@ -160,34 +156,9 @@ class HandleImpl: public HandleImplHelper {
     //! device ptr to const scalars
     ConstScalars* m_const_scalars;
-    std::unique_ptr<CUDNN> m_cudnn_api_cache;
     void initialize_cusolver();
 };
-class HandleImpl::CUDNN {
-    cudnnHandle_t m_handle;
-public:
-    CUDNN(cudnnHandle_t handle);
-#define WRAP_CUDNN_API(NAME) thin_function<decltype(cudnn##NAME)> NAME;
-    WRAP_CUDNN_API(GetConvolutionForwardWorkspaceSize);
-#if CUDNN_MAJOR >= 7
-    WRAP_CUDNN_API(GetConvolutionForwardAlgorithm_v7);
-    WRAP_CUDNN_API(GetConvolutionForwardAlgorithmMaxCount);
-#endif
-#if CUDNN_MAJOR >= 7
-    WRAP_CUDNN_API(GetConvolutionBackwardDataAlgorithm_v7);
-    WRAP_CUDNN_API(GetConvolutionBackwardDataAlgorithmMaxCount);
-#endif
-    WRAP_CUDNN_API(GetConvolutionBackwardDataWorkspaceSize);
-#if CUDNN_MAJOR >= 7
-    WRAP_CUDNN_API(GetConvolutionBackwardFilterAlgorithmMaxCount);
-    WRAP_CUDNN_API(GetConvolutionBackwardFilterAlgorithm_v7);
-#endif
-    WRAP_CUDNN_API(GetConvolutionBackwardFilterWorkspaceSize);
-#undef WRAP_CUDNN_API
-};
 } // namespace cuda
 } // namespace megdnn
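The removed WRAP_CUDNN_API macro above declared one thin_function member per wrapped cuDNN query, using decltype of the C entry point so each member keeps exactly the C signature. A standalone, hypothetical sketch of that declaration pattern (std::function stands in for megdnn::thin_function; query_workspace is an invented placeholder, not a real cuDNN symbol):

    #include <cstddef>
    #include <functional>

    // placeholder with a cuDNN-like signature
    int query_workspace(int handle, std::size_t* size) {
        *size = 64;
        return 0;
    }

    struct ApiTable {
        // decltype(query_workspace) is the function type int(int, std::size_t*),
        // so GetWorkspace is callable exactly like the wrapped C function.
        std::function<decltype(query_workspace)> GetWorkspace;
    };

    int main() {
        ApiTable table;
        table.GetWorkspace = &query_workspace;  // could also be a caching wrapper
        std::size_t ws = 0;
        return table.GetWorkspace(0, &ws);      // forwards to query_workspace
    }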