/** * \file dnn/src/cuda/api_cache.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ #pragma once #include "src/common/api_cache.h" #include "src/cuda/cudnn_wrapper.h" namespace megdnn { class CudnnConvDescParam { public: cudnnConvolutionDescriptor_t value; Empty serialize(StringSerializer& ser, Empty) { constexpr int nbDims = MEGDNN_MAX_NDIM; int padA[MEGDNN_MAX_NDIM]; int strideA[MEGDNN_MAX_NDIM]; int dilationA[MEGDNN_MAX_NDIM]; cudnnConvolutionMode_t mode; cudnnDataType_t computeType; cudnnGetConvolutionNdDescriptor(value, nbDims, &nbDims, padA, strideA, dilationA, &mode, &computeType); ser.write_plain(nbDims); for (int i = 0; i < nbDims; ++i) { ser.write_plain(padA[i]); ser.write_plain(strideA[i]); ser.write_plain(dilationA[i]); } ser.write_plain(mode); ser.write_plain(computeType); return Empty{}; } Empty deserialize(StringSerializer& ser, Empty) { int ndim = ser.read_plain(); int padA[MEGDNN_MAX_NDIM]; int strideA[MEGDNN_MAX_NDIM]; int dilationA[MEGDNN_MAX_NDIM]; for (int i = 0; i < ndim; ++i) { padA[i] = ser.read_plain(); strideA[i] = ser.read_plain(); dilationA[i] = ser.read_plain(); } cudnnConvolutionMode_t mode = ser.read_plain(); cudnnDataType_t computeType = ser.read_plain(); cudnnSetConvolutionNdDescriptor(value, ndim, padA, strideA, dilationA, mode, computeType); return Empty{}; } }; class CudnnTensorDescParam { public: cudnnTensorDescriptor_t value; Empty serialize(StringSerializer& ser, Empty) { constexpr int nbDims = MEGDNN_MAX_NDIM; cudnnDataType_t dataType; int dimA[MEGDNN_MAX_NDIM]; int strideA[MEGDNN_MAX_NDIM]; cudnnGetTensorNdDescriptor(value, nbDims, &dataType, &nbDims, dimA, strideA); ser.write_plain(nbDims); for (int i = 0; i < nbDims; ++i) { ser.write_plain(dimA[i]); ser.write_plain(strideA[i]); } ser.write_plain(dataType); return Empty{}; } Empty deserialize(StringSerializer& ser, Empty) { constexpr int nbDims = MEGDNN_MAX_NDIM; cudnnDataType_t dataType; int dimA[MEGDNN_MAX_NDIM]; int strideA[MEGDNN_MAX_NDIM]; nbDims = ser.read_plain(); for (int i = 0; i < nbDims; ++i) { dimA[i] = ser.read_plain(); strideA[i] = ser.read_plain(); } dataType = ser.read_plain(); cudnnSetTensorNdDescriptor(value, dataType, nbDims, dimA, strideA); return Empty{}; } }; class CudnnFilterDescParam { public: cudnnFilterDescriptor_t value; Empty serialize(StringSerializer& ser, Empty) { constexpr int nbDims = MEGDNN_MAX_NDIM; cudnnDataType_t dataType; cudnnTensorFormat_t format; int filterDimA[MEGDNN_MAX_NDIM]; cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims, filterDimA); ser.write_plain(nbDims); for (int i = 0; i < nbDims; ++i) { ser.write_plain(filterDimA[i]); } ser.write_plain(dataType); ser.write_plain(format); return Empty{}; } Empty deserialize(StringSerializer& ser, Empty) { constexpr int nbDims = MEGDNN_MAX_NDIM; cudnnDataType_t dataType; cudnnTensorFormat_t format; int filterDimA[MEGDNN_MAX_NDIM]; nbDims = ser.read_plain(); for (int i = 0; i < nbDims; ++i) { filterDimA[i] = ser.read_plain(); } dataType = ser.read_plain(); format = ser.read_plain(); cudnnSetFilterNdDescriptor(value, dataType, format, nbDims, filterDimA); return Empty{}; } }; } // namespace megdnn