| @@ -29,6 +29,7 @@ | |||
| #define MEGDNN_FLOAT16_SELECT(_x, _y) _y | |||
| #else | |||
| #include "megdnn/dtype/half.hpp" | |||
| #include "megdnn/dtype/bfloat16.hpp" | |||
| #define MEGDNN_INC_FLOAT16(_x) _x | |||
| #define MEGDNN_FLOAT16_SELECT(_x, _y) _x | |||
| #endif | |||
| @@ -49,6 +50,7 @@ namespace megdnn { | |||
| cb(IntB4) \ | |||
| cb(Byte) \ | |||
| MEGDNN_INC_FLOAT16(cb(Float16)) \ | |||
| MEGDNN_INC_FLOAT16(cb(BFloat16)) \ | |||
| cb(UintB4) \ | |||
| /*! | |||
| @@ -62,6 +64,7 @@ namespace megdnn { | |||
| cb(Int32) \ | |||
| cb(Byte) \ | |||
| MEGDNN_INC_FLOAT16(cb(Float16)) \ | |||
| MEGDNN_INC_FLOAT16(cb(BFloat16)) \ | |||
| /*! | |||
| * \brief iterate through each fractional byte dtype | |||
| @@ -101,6 +104,7 @@ namespace megdnn { | |||
| #define MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \ | |||
| cb(::megdnn::dtype::Float32) \ | |||
| MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::Float16)) \ | |||
| MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::BFloat16)) \ | |||
| /*! | |||
| * \brief iterate through each dtype object that can be involved in integer | |||
| @@ -345,6 +349,7 @@ typedef int16_t dt_int16; | |||
| typedef int8_t dt_int8; | |||
| typedef uint8_t dt_uint8; | |||
| MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;) | |||
| MEGDNN_INC_FLOAT16(typedef half_bfloat16::bfloat16 dt_bfloat16;) | |||
| #define MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE 100000 | |||
| #if MEGDNN_CC_HOST | |||
| @@ -367,6 +372,9 @@ MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;) | |||
| Float16, | |||
| #endif | |||
| UintB4 = 10, | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| BFloat16 = 11, | |||
| #endif | |||
| #define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE, | |||
| #define D(_name) _name, | |||
| @@ -702,6 +710,9 @@ MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX); | |||
| MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED, | |||
| std::numeric_limits<dt_float16>::lowest(), | |||
| std::numeric_limits<dt_float16>::max())); | |||
| MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(BFloat16, dt_bfloat16, FLOAT, SIGNED, | |||
| std::numeric_limits<dt_bfloat16>::lowest(), | |||
| std::numeric_limits<dt_bfloat16>::max())); | |||
| template <> | |||
| struct DTypeTrait<dtype::Byte> { | |||
| @@ -50,167 +50,7 @@ | |||
| #include <hip/hip_fp16.h> | |||
| #endif | |||
| /// Combined gcc version number. | |||
| #define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__) | |||
| //check C++11 language features | |||
| #if defined(__clang__) //clang | |||
| #if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||
| #endif | |||
| #if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) | |||
| #define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||
| #endif | |||
| #if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) | |||
| #define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||
| #endif | |||
| #if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) | |||
| #define HALF_ENABLE_CPP11_USER_LITERALS 1 | |||
| #endif | |||
| #if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||
| #endif | |||
| /*#elif defined(__INTEL_COMPILER) //Intel C++ | |||
| #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ???????? | |||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||
| #endif | |||
| #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? | |||
| #define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||
| #endif | |||
| #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? | |||
| #define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||
| #endif | |||
| #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? | |||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||
| #endif*/ | |||
| #elif defined(__GNUC__) //gcc | |||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L | |||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) | |||
| #define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) | |||
| #define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) | |||
| #define HALF_ENABLE_CPP11_USER_LITERALS 1 | |||
| #endif | |||
| #if !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||
| #endif | |||
| #endif | |||
| #elif defined(_MSC_VER) //Visual C++ | |||
| #if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||
| #endif | |||
| #if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||
| #endif | |||
| #define HALF_POP_WARNINGS 1 | |||
| #pragma warning(push) | |||
| //! 4521 and 4522 is multiple copy/assigment operator specified | |||
| #pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned | |||
| #endif | |||
| //check C++11 library features | |||
| #include <utility> | |||
| #if defined(_LIBCPP_VERSION) //libc++ | |||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 | |||
| #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS | |||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_CSTDINT | |||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_CMATH | |||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_HASH | |||
| #define HALF_ENABLE_CPP11_HASH 1 | |||
| #endif | |||
| #endif | |||
| #elif defined(__GLIBCXX__) //libstdc++ | |||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 | |||
| #ifdef __clang__ | |||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) | |||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||
| #endif | |||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) | |||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||
| #endif | |||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) | |||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||
| #endif | |||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) | |||
| #define HALF_ENABLE_CPP11_HASH 1 | |||
| #endif | |||
| #else | |||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) | |||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) | |||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) | |||
| #define HALF_ENABLE_CPP11_HASH 1 | |||
| #endif | |||
| #endif | |||
| #endif | |||
| #elif defined(_CPPLIB_VER) //Dinkumware/Visual C++ | |||
| #if _CPPLIB_VER >= 520 | |||
| #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS | |||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_CSTDINT | |||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_HASH | |||
| #define HALF_ENABLE_CPP11_HASH 1 | |||
| #endif | |||
| #endif | |||
| #if _CPPLIB_VER >= 610 | |||
| #ifndef HALF_ENABLE_CPP11_CMATH | |||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||
| #endif | |||
| #endif | |||
| #endif | |||
| #undef HALF_GNUC_VERSION | |||
| //support constexpr | |||
| #if HALF_ENABLE_CPP11_CONSTEXPR | |||
| #define HALF_CONSTEXPR constexpr | |||
| #define HALF_CONSTEXPR_CONST constexpr | |||
| #else | |||
| #define HALF_CONSTEXPR | |||
| #define HALF_CONSTEXPR_CONST const | |||
| #endif | |||
| //support noexcept | |||
| #if HALF_ENABLE_CPP11_NOEXCEPT | |||
| #define HALF_NOEXCEPT noexcept | |||
| #define HALF_NOTHROW noexcept | |||
| #else | |||
| #define HALF_NOEXCEPT | |||
| #define HALF_NOTHROW throw() | |||
| #endif | |||
| #include <algorithm> | |||
| #include <limits> | |||
| #include <climits> | |||
| #include <cmath> | |||
| #include <cstring> | |||
| #if HALF_ENABLE_CPP11_TYPE_TRAITS | |||
| #include <type_traits> | |||
| #endif | |||
| #if HALF_ENABLE_CPP11_CSTDINT | |||
| #include <cstdint> | |||
| #endif | |||
| #if HALF_ENABLE_CPP11_HASH | |||
| #include <functional> | |||
| #endif | |||
| #include "megdnn/dtype/half_common_prologue.h" | |||
| /// Default rounding mode. | |||
| /// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s and `float`s as well as | |||
| @@ -3141,16 +2981,7 @@ namespace std | |||
| #endif | |||
| } | |||
| #undef HALF_CONSTEXPR | |||
| #undef HALF_CONSTEXPR_CONST | |||
| #undef HALF_NOEXCEPT | |||
| #undef HALF_NOTHROW | |||
| #ifdef HALF_POP_WARNINGS | |||
| #pragma warning(pop) | |||
| #undef HALF_POP_WARNINGS | |||
| #endif | |||
| #include "megdnn/dtype/half_common_epilogue.h" | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,48 @@ | |||
| /** | |||
| * half - IEEE 754-based half-precision floating point library. | |||
| * | |||
| * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net> | |||
| * | |||
| * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation | |||
| * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, | |||
| * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the | |||
| * Software is furnished to do so, subject to the following conditions: | |||
| * | |||
| * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. | |||
| * | |||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE | |||
| * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR | |||
| * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, | |||
| * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |||
| * | |||
| * Version 1.11.0 | |||
| * \file | |||
| * Main header file for half precision functionality. | |||
| * | |||
| * -------------------------------------------------------------------------- | |||
| * \file include/megdnn/dtype/half_common_epilogue.h | |||
| * | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * | |||
| * This file has been modified by Megvii ("Megvii Modifications"). | |||
| * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. | |||
| * | |||
| * -------------------------------------------------------------------------- | |||
| */ | |||
| #undef HALF_CONSTEXPR | |||
| #undef HALF_CONSTEXPR_CONST | |||
| #undef HALF_NOEXCEPT | |||
| #undef HALF_NOTHROW | |||
| #ifdef HALF_POP_WARNINGS | |||
| #pragma warning(pop) | |||
| #undef HALF_POP_WARNINGS | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -0,0 +1,202 @@ | |||
| /** | |||
| * half - IEEE 754-based half-precision floating point library. | |||
| * | |||
| * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net> | |||
| * | |||
| * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation | |||
| * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, | |||
| * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the | |||
| * Software is furnished to do so, subject to the following conditions: | |||
| * | |||
| * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. | |||
| * | |||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE | |||
| * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR | |||
| * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, | |||
| * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |||
| * | |||
| * Version 1.11.0 | |||
| * \file | |||
| * Main header file for half precision functionality. | |||
| * | |||
| * -------------------------------------------------------------------------- | |||
| * \file dnn/include/megdnn/dtype/half_common_prologue.h | |||
| * | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * | |||
| * This file has been modified by Megvii ("Megvii Modifications"). | |||
| * All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved. | |||
| * | |||
| * -------------------------------------------------------------------------- | |||
| */ | |||
| #include "megdnn/arch.h" | |||
| /// Combined gcc version number. | |||
| #define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__) | |||
| //check C++11 language features | |||
| #if defined(__clang__) //clang | |||
| #if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||
| #endif | |||
| #if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) | |||
| #define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||
| #endif | |||
| #if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) | |||
| #define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||
| #endif | |||
| #if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) | |||
| #define HALF_ENABLE_CPP11_USER_LITERALS 1 | |||
| #endif | |||
| #if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||
| #endif | |||
| /*#elif defined(__INTEL_COMPILER) //Intel C++ | |||
| #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ???????? | |||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||
| #endif | |||
| #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? | |||
| #define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||
| #endif | |||
| #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? | |||
| #define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||
| #endif | |||
| #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? | |||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||
| #endif*/ | |||
| #elif defined(__GNUC__) //gcc | |||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L | |||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) | |||
| #define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) | |||
| #define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) | |||
| #define HALF_ENABLE_CPP11_USER_LITERALS 1 | |||
| #endif | |||
| #if !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||
| #endif | |||
| #endif | |||
| #elif defined(_MSC_VER) //Visual C++ | |||
| #if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||
| #endif | |||
| #if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||
| #endif | |||
| #define HALF_POP_WARNINGS 1 | |||
| #pragma warning(push) | |||
| //! 4521 and 4522 are multiple copy/assignment operators specified | |||
| #pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned | |||
| #endif | |||
| //check C++11 library features | |||
| #include <utility> | |||
| #if defined(_LIBCPP_VERSION) //libc++ | |||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 | |||
| #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS | |||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_CSTDINT | |||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_CMATH | |||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_HASH | |||
| #define HALF_ENABLE_CPP11_HASH 1 | |||
| #endif | |||
| #endif | |||
| #elif defined(__GLIBCXX__) //libstdc++ | |||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 | |||
| #ifdef __clang__ | |||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) | |||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||
| #endif | |||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) | |||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||
| #endif | |||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) | |||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||
| #endif | |||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) | |||
| #define HALF_ENABLE_CPP11_HASH 1 | |||
| #endif | |||
| #else | |||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) | |||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) | |||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) | |||
| #define HALF_ENABLE_CPP11_HASH 1 | |||
| #endif | |||
| #endif | |||
| #endif | |||
| #elif defined(_CPPLIB_VER) //Dinkumware/Visual C++ | |||
| #if _CPPLIB_VER >= 520 | |||
| #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS | |||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_CSTDINT | |||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_HASH | |||
| #define HALF_ENABLE_CPP11_HASH 1 | |||
| #endif | |||
| #endif | |||
| #if _CPPLIB_VER >= 610 | |||
| #ifndef HALF_ENABLE_CPP11_CMATH | |||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||
| #endif | |||
| #endif | |||
| #endif | |||
| #undef HALF_GNUC_VERSION | |||
| //support constexpr | |||
| #if HALF_ENABLE_CPP11_CONSTEXPR | |||
| #define HALF_CONSTEXPR constexpr | |||
| #define HALF_CONSTEXPR_CONST constexpr | |||
| #else | |||
| #define HALF_CONSTEXPR | |||
| #define HALF_CONSTEXPR_CONST const | |||
| #endif | |||
| //support noexcept | |||
| #if HALF_ENABLE_CPP11_NOEXCEPT | |||
| #define HALF_NOEXCEPT noexcept | |||
| #define HALF_NOTHROW noexcept | |||
| #else | |||
| #define HALF_NOEXCEPT | |||
| #define HALF_NOTHROW throw() | |||
| #endif | |||
| #include <algorithm> | |||
| #include <limits> | |||
| #include <climits> | |||
| #include <cmath> | |||
| #include <cstring> | |||
| #if HALF_ENABLE_CPP11_TYPE_TRAITS | |||
| #include <type_traits> | |||
| #endif | |||
| #if HALF_ENABLE_CPP11_CSTDINT | |||
| #include <cstdint> | |||
| #endif | |||
| #if HALF_ENABLE_CPP11_HASH | |||
| #include <functional> | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -30,7 +30,7 @@ def main(): | |||
| w('// generated by gen_cond_take_kern_impls.py') | |||
| w('#include "../kern.inl"') | |||
| w('') | |||
| if dtype == 'dt_float16': | |||
| if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||
| w('#if !MEGDNN_DISABLE_FLOAT16') | |||
| w('namespace megdnn {') | |||
| w('namespace cuda {') | |||
| @@ -48,7 +48,7 @@ def main(): | |||
| w('} // cond_take') | |||
| w('} // cuda') | |||
| w('} // megdnn') | |||
| if dtype == 'dt_float16': | |||
| if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||
| w('#endif') | |||
| print('generated {}'.format(fname)) | |||
| @@ -34,7 +34,7 @@ def main(): | |||
| w = lambda s: print(s, file=fout) | |||
| w('// generated by gen_elemwise_kern_impls.py') | |||
| if ctype == 'dt_float16': | |||
| if ctype == 'dt_float16' or ctype == 'dt_bfloat16': | |||
| w('#if !MEGDNN_DISABLE_FLOAT16') | |||
| w('#define KERN_IMPL_MODE(cb) {}'.format(formode)) | |||
| @@ -42,7 +42,7 @@ def main(): | |||
| w('#define KERN_IMPL_CTYPE {}'.format(ctype)) | |||
| w('#include "../kern_impl.inl"') | |||
| if ctype == 'dt_float16': | |||
| if ctype == 'dt_float16' or ctype == 'dt_bfloat16': | |||
| w('#endif') | |||
| print('generated {}'.format(fname)) | |||
| @@ -30,14 +30,14 @@ def main(): | |||
| w = lambda s: print(s, file=fout) | |||
| w('// generated by gen_elemwise_special_kern_impls.py') | |||
| if dtype == 'dt_float16': | |||
| if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||
| w('#if !MEGDNN_DISABLE_FLOAT16') | |||
| w('#include "../special_kerns.inl"') | |||
| w('INST(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) | |||
| w('#undef INST') | |||
| w('}') | |||
| w('}') | |||
| if dtype == 'dt_float16': | |||
| if dtype == 'dt_float16' or dtype == 'dt_bfloat16': | |||
| w('#endif') | |||
| print('generated {}'.format(fname)) | |||
| @@ -6,7 +6,8 @@ DTYPES = {'dt_int32': ('Int32', 'INT'), | |||
| 'dt_int8': ('Int8', 'INT'), | |||
| 'dt_int16': ('Int16', 'INT'), | |||
| 'dt_float32': ('Float32', 'FLOAT'), | |||
| 'dt_float16': ('Float16', 'FLOAT') | |||
| 'dt_float16': ('Float16', 'FLOAT'), | |||
| 'dt_bfloat16': ('BFloat16', 'FLOAT') | |||
| } | |||
| MODES = { | |||
| @@ -618,9 +618,10 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src, | |||
| megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32 | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| || src.enumv() == DTypeEnum::Float16 | |||
| || src.enumv() == DTypeEnum::BFloat16 | |||
| #endif | |||
| , | |||
| "ComputeMode::FLOAT32 is only available for Float16 " | |||
| , | |||
| "ComputeMode::FLOAT32 is only available for Float16/BFloat16 " | |||
| "input / output."); | |||
| } | |||
| @@ -1036,9 +1037,10 @@ void ConvolutionBackwardData::deduce_dtype(DType filter, DType diff, | |||
| megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32 | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| || filter.enumv() == DTypeEnum::Float16 | |||
| || filter.enumv() == DTypeEnum::BFloat16 | |||
| #endif | |||
| , | |||
| "ComputeMode::FLOAT32 is only available for Float16 " | |||
| , | |||
| "ComputeMode::FLOAT32 is only available for Float16/BFloat16 " | |||
| "input / output."); | |||
| } | |||
| @@ -87,7 +87,8 @@ namespace megdnn { | |||
| //! define kernel for all float types | |||
| #define DEF_KERN_FLOAT(_mode, _imp) \ | |||
| DEF_KERN(dt_float32, _mode, _imp); \ | |||
| MEGDNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) | |||
| MEGDNN_INC_FLOAT16(DEF_KERN(dt_float16, _mode, _imp);) \ | |||
| MEGDNN_INC_FLOAT16(DEF_KERN(dt_bfloat16, _mode, _imp);) | |||
| //! define kernel for all int types | |||
| #define DEF_KERN_INT(_mode, _imp) \ | |||
| @@ -69,11 +69,11 @@ void MatrixMulForward::deduce_layout(const TensorLayout& A, | |||
| C = TensorLayout(TensorShape({A0, B1}), C.dtype); | |||
| } else { | |||
| auto do_deduce = [&](size_t pack_size) { | |||
| megdnn_assert( | |||
| A.ndim == 4 && B.ndim == 3, | |||
| "matmul requires input dimension to be A(4), B(3); get: %s %s", | |||
| A.TensorShape::to_string().c_str(), | |||
| B.TensorShape::to_string().c_str()); | |||
| megdnn_assert(A.ndim == 4 && B.ndim == 3, | |||
| "matmul requires input dimension to be A(4), B(3); " | |||
| "get: %s %s", | |||
| A.TensorShape::to_string().c_str(), | |||
| B.TensorShape::to_string().c_str()); | |||
| A0 = A.shape[0]; | |||
| A1 = A.shape[1]; | |||
| B0 = B.shape[0]; | |||
| @@ -82,11 +82,11 @@ void MatrixMulForward::deduce_layout(const TensorLayout& A, | |||
| std::swap(A0, A1); | |||
| if (m_param.transposeB) | |||
| std::swap(B0, B1); | |||
| megdnn_assert( | |||
| A1 == B0, | |||
| "shape mismatch in matmal: (transposed) A is (%zu,%zu,4,4), " | |||
| "(transposed) B is (%zu,%zu,4)", | |||
| A0, A1, B0, B1); | |||
| megdnn_assert(A1 == B0, | |||
| "shape mismatch in matmal: (transposed) A is " | |||
| "(%zu,%zu,4,4), " | |||
| "(transposed) B is (%zu,%zu,4)", | |||
| A0, A1, B0, B1); | |||
| C = TensorLayout(TensorShape({A0, B1, pack_size}), C.dtype); | |||
| }; | |||
| do_deduce(pack_size(param().format)); | |||
| @@ -172,8 +172,9 @@ void MatrixMulForward::check_exec(const TensorLayout& A, const TensorLayout& B, | |||
| } | |||
| megdnn_assert(param().compute_mode != | |||
| Param::ComputeMode::FLOAT32 MEGDNN_INC_FLOAT16( | |||
| || A.dtype == dtype::Float16()), | |||
| "ComputeMode::FLOAT32 is only available for Float16 " | |||
| || A.dtype == dtype::Float16() || | |||
| A.dtype == dtype::BFloat16()), | |||
| "ComputeMode::FLOAT32 is only available for Float16/BFloat16 " | |||
| "input / output."); | |||
| auto required_workspace_in_bytes = get_workspace_in_bytes(A, B, C); | |||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
| @@ -46,6 +46,14 @@ struct RoundingConverter<half_float::half> { | |||
| } | |||
| }; | |||
| template <> | |||
| struct RoundingConverter<half_bfloat16::bfloat16> { | |||
| __host__ __device__ __forceinline__ half_bfloat16::bfloat16 operator()( | |||
| float x) const { | |||
| return static_cast<half_bfloat16::bfloat16>(x); | |||
| } | |||
| }; | |||
| #endif // #ifdef MEGDNN_DISABLE_FLOAT16 | |||
| template <> | |||
| @@ -16,6 +16,7 @@ | |||
| #include "megdnn/dtype.h" | |||
| #include "megdnn/handle.h" | |||
| #include "megdnn/thin/small_vector.h" | |||
| #include "megdnn/oprs/general.h" | |||
| #include "src/common/hash_ct.h" | |||
| #include "src/common/utils.cuh" | |||
| @@ -548,6 +549,59 @@ public: | |||
| std::string to_string() const; | |||
| }; | |||
| /**! | |||
| * \brief helpers for oprs using typecvt between comp_type and dst_type | |||
| * \tparam SrcType src type | |||
| * \tparam CompType compute type, such as fp32 for conv | |||
| * \tparam DstType dst type | |||
| */ | |||
| template <typename SrcType, typename CompType, typename DstType = SrcType> | |||
| struct CompTypeCvter { | |||
| std::unique_ptr<TypeCvt> m_cvt_opr; | |||
| WorkspaceBundle* m_workspace_bundle; | |||
| size_t m_workspace_idx; | |||
| CompTypeCvter(Handle* handle, WorkspaceBundle* bundle) | |||
| : m_workspace_bundle(bundle), m_workspace_idx(0) { | |||
| megdnn_assert( | |||
| (DTypeTrait<SrcType>::enumv != DTypeTrait<CompType>::enumv && | |||
| DTypeTrait<DstType>::enumv != DTypeTrait<CompType>::enumv), | |||
| "SrcType(%s) == CompType(%s) or DstType(%s) == CompType(%s) is " | |||
| "not " | |||
| "supportted.", | |||
| SrcType().name(), CompType().name(), DstType().name(), | |||
| CompType().name()); | |||
| m_cvt_opr = handle->create_operator<TypeCvt>(); | |||
| } | |||
| //! Convert tensor dtype from SrcType to CompType. | |||
| CompTypeCvter& src_to_comp_type(const TensorND& src, TensorND& comp) { | |||
| if (src.layout.dtype.enumv() == DTypeTrait<SrcType>::enumv) { | |||
| if (!comp.layout.dtype.valid() || | |||
| comp.layout.dtype.enumv() != DTypeTrait<CompType>::enumv) { | |||
| comp.layout.dtype = CompType(); | |||
| comp.layout.init_contiguous_stride(); | |||
| comp.raw_ptr = m_workspace_bundle->get(m_workspace_idx++); | |||
| if (src.layout.ndim) { | |||
| m_cvt_opr->exec(src, comp); | |||
| } | |||
| } | |||
| } | |||
| return *this; | |||
| } | |||
| //! Convert tensor dtype from CompType to DstType. | |||
| CompTypeCvter& comp_to_dst_type(const TensorND& comp, const TensorND& dst) { | |||
| megdnn_assert(comp.layout.dtype.enumv() == DTypeTrait<CompType>::enumv); | |||
| if (dst.layout.dtype.enumv() == DTypeTrait<DstType>::enumv) { | |||
| m_cvt_opr->exec(comp, dst); | |||
| } | |||
| return *this; | |||
| } | |||
| Workspace workspace() { | |||
| return m_workspace_bundle->get_workspace(m_workspace_idx); | |||
| } | |||
| }; | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -55,17 +55,19 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src, | |||
| megdnn_assert(mat.shape[2] == 3_z, "%s", errmsg().c_str()); | |||
| if (param().format == param::WarpPerspective::Format::NCHW) { | |||
| megdnn_assert(src.dtype.enumv() == DTypeEnum::Float32 || | |||
| MEGDNN_FLOAT16_SELECT( | |||
| src.dtype.enumv() == DTypeEnum::Float16, | |||
| false) || | |||
| src.dtype.enumv() == DTypeEnum::Int8 || | |||
| src.dtype.enumv() == DTypeEnum::Uint8 || | |||
| (src.dtype.enumv() == DTypeEnum::QuantizedS8 || | |||
| src.dtype.enumv() == DTypeEnum::Quantized8Asymm), | |||
| "WarpPerspective NCHW input dtype should be " | |||
| "Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT( | |||
| "/Float16", "") "."); | |||
| megdnn_assert( | |||
| src.dtype.enumv() == DTypeEnum::Float32 || | |||
| MEGDNN_FLOAT16_SELECT( | |||
| (src.dtype.enumv() == DTypeEnum::Float16 || | |||
| src.dtype.enumv() == DTypeEnum::BFloat16), | |||
| false) || | |||
| src.dtype.enumv() == DTypeEnum::Int8 || | |||
| src.dtype.enumv() == DTypeEnum::Uint8 || | |||
| (src.dtype.enumv() == DTypeEnum::QuantizedS8 || | |||
| src.dtype.enumv() == DTypeEnum::Quantized8Asymm), | |||
| "WarpPerspective NCHW input dtype should be " | |||
| "Float32/Int8/Uint8/QInt8/QUint8" MEGDNN_FLOAT16_SELECT( | |||
| "/Float16/BFloat16", "") "."); | |||
| megdnn_assert( | |||
| (src.dtype.category() == DTypeCategory::FLOAT && | |||
| (src.dtype == mat.dtype || | |||
| @@ -107,14 +109,17 @@ void WarpPerspectiveBase::check_layout_fwd(const TensorLayout &src, | |||
| param::WarpPerspective::BorderMode::ISOLATED); | |||
| } else { | |||
| megdnn_assert(param().format == param::WarpPerspective::Format::NHWCD4); | |||
| megdnn_assert(src.dtype == dtype::Float32() || | |||
| MEGDNN_FLOAT16_SELECT( | |||
| src.dtype == dtype::Float16(), false) || | |||
| src.dtype.enumv() == DTypeEnum::QuantizedS8 || | |||
| src.dtype.enumv() == DTypeEnum::Quantized8Asymm, | |||
| "WarpPerspective NHWCD4 input dtype should be " | |||
| "Float32" MEGDNN_FLOAT16_SELECT( | |||
| "/Float16", "") ",QunatizedS8, Quantized8Asymm."); | |||
| megdnn_assert( | |||
| src.dtype == dtype::Float32() || | |||
| MEGDNN_FLOAT16_SELECT((src.dtype == dtype::Float16() || | |||
| src.dtype == dtype::BFloat16()), | |||
| false) || | |||
| src.dtype.enumv() == DTypeEnum::QuantizedS8 || | |||
| src.dtype.enumv() == DTypeEnum::Quantized8Asymm, | |||
| "WarpPerspective NHWCD4 input dtype should be " | |||
| "Float32" MEGDNN_FLOAT16_SELECT( | |||
| "/Float16/BFloat16", | |||
| "") ",QunatizedS8, Quantized8Asymm."); | |||
| megdnn_assert( | |||
| (src.dtype == mat.dtype || mat.dtype == dtype::Float32()), | |||
| "The input to WarpPerspective is in NHWCD4 format, in this " | |||
| @@ -253,30 +258,30 @@ void WarpPerspectiveForward::check_exec_allow_nhwc_mat_idx( | |||
| } | |||
| } | |||
| void WarpPerspectiveBackwardData::check_exec(const TensorLayout &mat, | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad, | |||
| size_t workspace_in_bytes) | |||
| { | |||
| void WarpPerspectiveBackwardData::check_exec(const TensorLayout& mat, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_in_bytes) { | |||
| check_layout_fwd(grad, mat, diff); | |||
| megdnn_assert(grad.dtype == dtype::Float32(), | |||
| "Backward WarpPerspective only supports Float32."); | |||
| megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16( | |||
| || grad.dtype == dtype::BFloat16()), | |||
| "Backward WarpPerspective only supports Float32/BFloat16."); | |||
| auto required_workspace_in_bytes = get_workspace_in_bytes(mat, diff, grad); | |||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
| } | |||
| void WarpPerspectiveBackwardMat::check_exec(const TensorLayout &src, | |||
| const TensorLayout &mat, | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad, | |||
| size_t workspace_in_bytes) | |||
| { | |||
| void WarpPerspectiveBackwardMat::check_exec(const TensorLayout& src, | |||
| const TensorLayout& mat, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_in_bytes) { | |||
| check_layout_fwd(src, mat, diff); | |||
| megdnn_assert_eq_layout(mat, grad); | |||
| megdnn_assert(grad.dtype == dtype::Float32(), | |||
| "Backward WarpPerspective only supports Float32."); | |||
| auto required_workspace_in_bytes = get_workspace_in_bytes(src, | |||
| mat, diff, grad); | |||
| megdnn_assert(grad.dtype == dtype::Float32() MEGDNN_INC_FLOAT16( | |||
| || grad.dtype == dtype::BFloat16()), | |||
| "Backward WarpPerspective only supports Float32/BFloat16."); | |||
| auto required_workspace_in_bytes = | |||
| get_workspace_in_bytes(src, mat, diff, grad); | |||
| megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
| } | |||
| @@ -0,0 +1,29 @@ | |||
| /** | |||
| * \file dnn/src/cuda/cond_take/kimpl/dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_cond_take_kern_impls.py | |||
#include "../kern.inl"

// Kernel instantiations are compiled only when fp16/bf16 support is enabled.
#if !MEGDNN_DISABLE_FLOAT16
namespace megdnn {
namespace cuda {
namespace cond_take {

// Instantiate the cond_take kernels for BFloat16; the inst_genidx /
// inst_copy macros are defined in kern.inl and #undef'd here so each
// generated .cu file expands them exactly once.
inst_genidx(::megdnn::dtype::BFloat16)
#undef inst_genidx

inst_copy(::megdnn::dtype::BFloat16)
#undef inst_copy
#undef inst_copy_

} // cond_take
} // cuda
} // megdnn
#endif
| @@ -62,6 +62,13 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() { | |||
| non_cudnn_algos.push_back(all_algos.rbegin()[1]); // group batched_matmul | |||
| non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group 1x1 | |||
| algo_size = all_algos.size(); | |||
| for (size_t i = 0; i < algo_size; ++i) { | |||
| bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i])); | |||
| all_algos.push_back(bfloat16_refhold.back().get()); | |||
| bfloat16_algos.push_back(bfloat16_refhold.back().get()); | |||
| } | |||
| size_t all_algo_size = all_algos.size(); | |||
| #if CUDA_VERSION >= 10000 | |||
| fill_imma_algos(); | |||
| @@ -499,6 +499,28 @@ private: | |||
| }; | |||
| #endif | |||
//! Wrapper algorithm: serves BFloat16 conv-bias problems by converting
//! the operands to Float32, running a wrapped Float32 algorithm, and
//! converting the result back.
class ConvBiasForwardImpl::AlgoBFloat16 final : public AlgoBase {
public:
    AlgoBFloat16(AlgoBase* impl);
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    //! reproducibility is inherited from the wrapped algorithm
    bool is_reproducible() const override { return m_impl->is_reproducible(); }

private:
    //! build Float32 SizeArgs (and the retyped layouts) for the wrapped impl
    SizeArgs float_args(const SizeArgs& args, ConvBiasForwardImpl* opr,
                        TensorLayout& fsrc, TensorLayout& ffilter,
                        TensorLayout& fbias, TensorLayout& fz,
                        TensorLayout& fdst) const;
    //! conversion buffers plus the wrapped impl's own workspace
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
    AlgoBase* m_impl;    //!< wrapped Float32 algorithm; owned by AlgoPack
    std::string m_name;
};
| class ConvBiasForwardImpl::AlgoPack { | |||
| AlgoPack(const AlgoPack&) = delete; | |||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||
| @@ -508,7 +530,8 @@ public: | |||
| std::vector<AlgoBase*> all_algos, | |||
| //! non-cudnn algos, used for heuristic if cudnn is not supported | |||
| non_cudnn_algos; | |||
| non_cudnn_algos, | |||
| bfloat16_algos; | |||
| std::vector<AlgoCUDNNConvBiasActivation> cudnn_conv_bias_activations; | |||
| std::vector<AlgoCUDNNConv> cudnn_convs; | |||
| AlgoChanwise chanwise; | |||
| @@ -531,6 +554,7 @@ public: | |||
| int8_chwn4_imma_unroll_width; | |||
| #endif | |||
| std::vector<std::unique_ptr<AlgoGroupConvGeneral>> gconv_refhold; | |||
| std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
| AlgoBase* cudnn_conv_bias_act_from_enum(cudnnConvolutionFwdAlgo_t algo); | |||
| @@ -0,0 +1,120 @@ | |||
| /** | |||
| * \file dnn/src/cuda/conv_bias/bfloat16.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "src/cuda/conv_bias/algo.h" | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/utils.cuh" | |||
| #include "src/cuda/utils.h" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace conv_bias; | |||
| ConvBiasForwardImpl::AlgoBFloat16::AlgoBFloat16( | |||
| ConvBiasForwardImpl::AlgoBase* algorithm) | |||
| : m_impl(algorithm) { | |||
| megdnn_assert_internal(algorithm); | |||
| m_name = ssprintf("BFLOAT16:%s", m_impl->name()); | |||
| } | |||
//! Build a Float32 "shadow" SizeArgs for the wrapped algorithm: copy all
//! five layouts, retype any BFloat16 layout to Float32, and configure the
//! scratch operator \p opr to run the wrapped impl in DEFAULT compute mode.
//! The retyped layouts are returned through the out-params so callers can
//! size conversion buffers from them.
ConvBiasForwardImpl::AlgoBase::SizeArgs
ConvBiasForwardImpl::AlgoBFloat16::float_args(
        const SizeArgs& args, ConvBiasForwardImpl* opr, TensorLayout& fsrc,
        TensorLayout& ffilter, TensorLayout& fbias, TensorLayout& fz,
        TensorLayout& fdst) const {
    fsrc = *args.src_layout;
    ffilter = *args.filter_layout;
    fbias = *args.bias_layout;
    fz = *args.z_layout;
    fdst = *args.dst_layout;
    // Only BFloat16 layouts are rewritten; any other dtype passes through.
    auto change_dtype = [](TensorLayout& layout) {
        if (layout.dtype == dtype::BFloat16()) {
            layout.dtype = dtype::Float32();
        }
    };
    change_dtype(fsrc);
    change_dtype(ffilter);
    change_dtype(fbias);
    change_dtype(fz);
    change_dtype(fdst);
    // Mirror the caller's param, but force plain-Float32 computation and
    // pin the execution policy to the wrapped algorithm.
    opr->param() = args.opr->param();
    opr->param().compute_mode = Param::ComputeMode::DEFAULT;
    opr->execution_policy() = {m_impl};
    return SizeArgs(opr, fsrc, ffilter, fbias, fz, fdst);
}
//! Usable only when src and filter are both BFloat16 and the wrapped
//! algorithm can handle the Float32 version of the problem.
bool ConvBiasForwardImpl::AlgoBFloat16::is_available(
        const SizeArgs& args) const {
    TensorLayout fsrc, ffilter, fbias, fz, fdst;
    // Scratch operator used only to form the Float32 SizeArgs.
    auto convbias_opr = args.handle->create_operator<ConvBias>();
    SizeArgs fargs = float_args(
            args, static_cast<ConvBiasForwardImpl*>(convbias_opr.get()), fsrc,
            ffilter, fbias, fz, fdst);
    return args.src_layout->dtype == args.filter_layout->dtype &&
           args.src_layout->dtype == dtype::BFloat16() &&
           m_impl->is_available(fargs);
}
//! Workspace layout: one Float32 conversion buffer per tensor whose dtype
//! was rewritten by float_args(), followed by the wrapped algorithm's own
//! workspace as the last chunk. exec() relies on this chunk order.
WorkspaceBundle ConvBiasForwardImpl::AlgoBFloat16::get_workspace_bundle(
        void* ptr, const SizeArgs& args) const {
    TensorLayout fsrc, ffilter, fbias, fz, fdst;
    auto convbias_opr = args.handle->create_operator<ConvBias>();
    SizeArgs fargs = float_args(
            args, static_cast<ConvBiasForwardImpl*>(convbias_opr.get()), fsrc,
            ffilter, fbias, fz, fdst);
    SmallVector<size_t> sizes;
    // A conversion buffer is needed only when the dtype actually changed.
    auto get_workspace = [&sizes](const TensorLayout& src,
                                  const TensorLayout& dst) {
        if (src.dtype != dst.dtype) {
            sizes.push_back(dst.span().dist_byte());
        }
    };
    get_workspace(*args.src_layout, fsrc);
    get_workspace(*args.filter_layout, ffilter);
    get_workspace(*args.bias_layout, fbias);
    get_workspace(*args.z_layout, fz);
    get_workspace(*args.dst_layout, fdst);
    // Last chunk: workspace of the wrapped Float32 algorithm.
    sizes.push_back(m_impl->get_workspace_in_bytes(fargs));
    return {ptr, std::move(sizes)};
}
| size_t ConvBiasForwardImpl::AlgoBFloat16::get_workspace_in_bytes( | |||
| const SizeArgs& args) const { | |||
| return get_workspace_bundle(nullptr, args).total_size_in_bytes(); | |||
| } | |||
//! Run the wrapped Float32 algorithm on BFloat16 tensors:
//! 1) convert all BFloat16 operands to Float32 into workspace buffers,
//! 2) execute the wrapped conv-bias in Float32,
//! 3) convert the Float32 result back into the BFloat16 dst tensor.
void ConvBiasForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const {
    // Start as aliases of the originals; CompTypeCvter redirects each one
    // to a Float32 workspace buffer when its dtype needs conversion.
    TensorND fsrc_tensor = *args.src_tensor;
    TensorND ffilter_tensor = *args.filter_tensor;
    TensorND fbias_tensor = *args.bias_tensor;
    TensorND fz_tensor = *args.z_tensor;
    TensorND fdst_tensor = *args.dst_tensor;
    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
    CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle);
    {
        // Conversion order must match the buffer order in the bundle.
        cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor)
                .src_to_comp_type(*args.filter_tensor, ffilter_tensor)
                .src_to_comp_type(*args.bias_tensor, fbias_tensor)
                .src_to_comp_type(*args.z_tensor, fz_tensor)
                .src_to_comp_type(*args.dst_tensor, fdst_tensor);
    }
    {
        auto convbias_opr = args.handle->create_operator<ConvBias>();
        convbias_opr->param() = args.opr->param();
        // Same configuration float_args() uses for sizing, so the wrapped
        // impl sees a consistent problem.
        convbias_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
        convbias_opr->execution_policy() = {m_impl};
        // cvter.workspace() is the remaining (last) chunk of the bundle.
        convbias_opr->exec(fsrc_tensor, ffilter_tensor, fbias_tensor, fz_tensor,
                           fdst_tensor, cvter.workspace());
    }
    { cvter.comp_to_dst_type(fdst_tensor, *args.dst_tensor); }
}
| // vim: syntax=cpp.doxygen | |||
| @@ -20,6 +20,10 @@ using namespace conv_bias; | |||
| bool ConvBiasForwardImpl::AlgoChanwise::is_available( | |||
| const SizeArgs& args) const { | |||
| if (args.src_layout->dtype == args.filter_layout->dtype && | |||
| args.src_layout->dtype == dtype::BFloat16()) { | |||
| return false; | |||
| } | |||
| if (args.z_layout->ndim > 0) | |||
| return false; | |||
| @@ -30,6 +30,10 @@ inline bool is_available_small(const chanwise::Param& param) { | |||
| bool ConvBiasForwardImpl::AlgoChanwiseSmall::is_available( | |||
| const SizeArgs& args) const { | |||
| if (args.src_layout->dtype == args.filter_layout->dtype && | |||
| args.src_layout->dtype == dtype::BFloat16()) { | |||
| return false; | |||
| } | |||
| if (args.z_layout->ndim > 0) | |||
| return false; | |||
| #if CUDA_VERSION < 9000 | |||
| @@ -23,6 +23,10 @@ using namespace conv_bias; | |||
| bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( | |||
| const SizeArgs& args) const { | |||
| if (args.src_layout->dtype == args.filter_layout->dtype && | |||
| args.src_layout->dtype == dtype::BFloat16()) { | |||
| return false; | |||
| } | |||
| if (args.bias_layout->ndim == 0 || | |||
| args.bias_layout->eq_shape(*args.dst_layout)) | |||
| return false; | |||
| @@ -50,6 +50,10 @@ ConvBiasForwardImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral(AlgoBase* impl) | |||
| bool ConvBiasForwardImpl::AlgoGroupConvGeneral::is_available( | |||
| const SizeArgs& args) const { | |||
| if (args.src_layout->dtype == args.filter_layout->dtype && | |||
| args.src_layout->dtype == dtype::BFloat16()) { | |||
| return false; | |||
| } | |||
| if (args.z_layout->ndim > 0 || args.filter_meta.group <= 1) | |||
| return false; | |||
| auto&& param = args.opr->param(); | |||
| @@ -136,6 +136,11 @@ void ConvBiasDesc::set_conv(DType data_type, const param::ConvBias& param, | |||
| namespace conv_bias { | |||
| bool is_cudnn_supported(const BiasForwardSizeArgs& args) { | |||
| if (args.src_layout->dtype == args.filter_layout->dtype && | |||
| args.src_layout->dtype == dtype::BFloat16()) { | |||
| return false; | |||
| } | |||
| // CUDNN_STATUS_EXECUTION_FAILED on Tegra K1, so disable CUDNN | |||
| // on Tegra K1. | |||
| if (args.handle->is_tegra_k1()) | |||
| @@ -20,6 +20,10 @@ using namespace cuda; | |||
| using namespace conv_bias; | |||
| bool ConvBiasForwardImpl::AlgoMatmul::is_available(const SizeArgs& args) const { | |||
| if (args.src_layout->dtype == args.filter_layout->dtype && | |||
| args.src_layout->dtype == dtype::BFloat16()) { | |||
| return false; | |||
| } | |||
| if (args.z_layout->ndim > 0) | |||
| return false; | |||
| @@ -9,6 +9,7 @@ | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "src/cuda/conv_bias/opr_impl.h" | |||
| #include "megdnn/dtype.h" | |||
| #include "src/cuda/conv_bias/helper.h" | |||
| #include "src/cuda/conv_bias/algo.h" | |||
| #include "src/cuda/handle.h" | |||
| @@ -176,14 +177,26 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||
| conv_args = orig_args; | |||
| } | |||
| if (reproducible) { | |||
| return megdnn::get_reproducible_algo<ConvBiasForwardImpl>( | |||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||
| "cuda convbias fwd"); | |||
| if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) { | |||
| if (reproducible) { | |||
| return megdnn::get_reproducible_algo<ConvBiasForwardImpl>( | |||
| sm_algo_pack.non_cudnn_algos, args, | |||
| workspace_limit_in_bytes, "cuda convbias fwd"); | |||
| } else { | |||
| return megdnn::get_usable_algo<ConvBiasForwardImpl>( | |||
| sm_algo_pack.non_cudnn_algos, args, | |||
| workspace_limit_in_bytes, "cuda convbias fwd"); | |||
| } | |||
| } else { | |||
| return megdnn::get_usable_algo<ConvBiasForwardImpl>( | |||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||
| "cuda convbias fwd"); | |||
| if (reproducible) { | |||
| return megdnn::get_reproducible_algo<ConvBiasForwardImpl>( | |||
| sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||
| "cuda convbias fwd"); | |||
| } else { | |||
| return megdnn::get_usable_algo<ConvBiasForwardImpl>( | |||
| sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||
| "cuda convbias fwd"); | |||
| } | |||
| } | |||
| } | |||
| @@ -57,6 +57,7 @@ public: | |||
| class AlgoInt8NCHW4IMMAImplicitGemm; | |||
| class AlgoInt8CHWN4IMMAImplicitGemmReorderFilter; | |||
| class AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth; | |||
| class AlgoBFloat16; | |||
| class AlgoPack; | |||
| @@ -33,11 +33,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() { | |||
| // add gconv algos by AlgoGroupConvGeneral | |||
| auto all_algos_data = all_algos.data(); | |||
| for (size_t i = 2; i < all_algos.size(); ++ i) { | |||
| size_t group_algo_start = 2; | |||
| for (size_t i = group_algo_start; i < all_algos.size(); ++ i) { | |||
| gconv.push_back({all_algos[i]}); | |||
| } | |||
| for (size_t i = 2; i < all_algos.size(); ++ i) { | |||
| algo2gconv[all_algos[i]] = &gconv[i - 2]; | |||
| for (size_t i = group_algo_start; i < all_algos.size(); ++ i) { | |||
| algo2gconv[all_algos[i]] = &gconv[i - group_algo_start]; | |||
| } | |||
| for (auto &&i: gconv) { | |||
| all_algos.push_back(&i); | |||
| @@ -45,6 +46,12 @@ ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() { | |||
| megdnn_assert(all_algos_data == all_algos.data()); | |||
| non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group matmul | |||
| size_t algo_size = all_algos.size(); | |||
| for (size_t i=0; i<algo_size; ++i) { | |||
| bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i])); | |||
| all_algos.push_back(bfloat16_refhold.back().get()); | |||
| bfloat16_algos.push_back(bfloat16_refhold.back().get()); | |||
| } | |||
| } | |||
| ConvolutionBackwardDataImpl::AlgoCUDNN* | |||
| @@ -65,18 +72,19 @@ ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs( | |||
| ConvolutionBackwardDataImpl *o, | |||
| const TensorLayout &filter, const TensorLayout &diff, | |||
| const TensorLayout &grad): | |||
| SizeArgs(o, o->check_layout_fwd(grad, filter, diff), diff, grad) | |||
| SizeArgs(o, filter, o->check_layout_fwd(grad, filter, diff), diff, grad) | |||
| { | |||
| } | |||
| ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs( | |||
| ConvolutionBackwardDataImpl *o, | |||
| const CanonizedFilterMeta &filter, const TensorLayout &diff, | |||
| ConvolutionBackwardDataImpl *o, const TensorLayout& filter, | |||
| const CanonizedFilterMeta &filter_meta, const TensorLayout &diff, | |||
| const TensorLayout &grad): | |||
| handle{concrete_handle(o->handle())}, | |||
| filter_meta{filter}, | |||
| filter_meta{filter_meta}, | |||
| diff_layout{&diff}, | |||
| grad_layout{&grad}, | |||
| filter_layout{&filter}, | |||
| opr{o} | |||
| { | |||
| } | |||
| @@ -31,22 +31,24 @@ class ConvolutionBackwardDataImpl::AlgoBase: public Algorithm { | |||
| struct SizeArgs { | |||
| HandleImpl *handle; | |||
| CanonizedFilterMeta filter_meta; | |||
| const TensorLayout *diff_layout, *grad_layout; | |||
| const TensorLayout *diff_layout, *grad_layout, *filter_layout; | |||
| ConvolutionBackwardDataImpl *opr; | |||
| std::string to_string() const; | |||
| void init_desc(convolution::CUDNNBwdDataDescs &desc) const { | |||
| desc.set(filter_meta, *diff_layout, *grad_layout, opr->param()); | |||
| } | |||
| SizeArgs(ConvolutionBackwardDataImpl *opr, | |||
| const TensorLayout &filter, const TensorLayout &diff, | |||
| const TensorLayout &grad); | |||
| SizeArgs(ConvolutionBackwardDataImpl *opr, | |||
| const CanonizedFilterMeta &filter, const TensorLayout &diff, | |||
| const TensorLayout &grad); | |||
| SizeArgs(ConvolutionBackwardDataImpl* opr, | |||
| const TensorLayout& filter, const TensorLayout& diff, | |||
| const TensorLayout& grad); | |||
| SizeArgs(ConvolutionBackwardDataImpl* opr, | |||
| const TensorLayout& filter, | |||
| const CanonizedFilterMeta& filter_meta, | |||
| const TensorLayout& diff, const TensorLayout& grad); | |||
| convolution::ForwardSizeArgs as_fwd_args() const { | |||
| return {handle, grad_layout, filter_meta, diff_layout}; | |||
| return {handle, grad_layout, filter_layout, filter_meta, | |||
| diff_layout}; | |||
| } | |||
| }; | |||
| struct ExecArgs: public SizeArgs { | |||
| @@ -170,6 +172,25 @@ class ConvolutionBackwardDataImpl::AlgoChanwiseSmall final: public AlgoBase { | |||
| } | |||
| }; | |||
//! Wrapper algorithm: serves BFloat16 backward-data problems by converting
//! operands to Float32, running a wrapped Float32 algorithm, and converting
//! the gradient back.
class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase {
public:
    AlgoBFloat16(ConvolutionBackwardDataImpl::AlgoBase*);
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    bool is_reproducible() const override { return true; }

private:
    std::string m_name;
    //! wrapped Float32 algorithm; owned by AlgoPack
    ConvolutionBackwardDataImpl::AlgoBase* m_algorithm = nullptr;
    //! build Float32 SizeArgs (and the retyped layouts) for the wrapped impl
    SizeArgs float_args(const SizeArgs& args, ConvolutionBackwardDataImpl* opr,
                        TensorLayout& fsrc, TensorLayout& ffilter,
                        TensorLayout& fdst) const;
    //! conversion buffers plus the wrapped impl's own workspace
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};
| //! implement group conv by another algo | |||
| class ConvolutionBackwardDataImpl::AlgoGroupConvGeneral final: public AlgoBase { | |||
| AlgoBase *m_impl; | |||
| @@ -210,12 +231,14 @@ class ConvolutionBackwardDataImpl::AlgoPack { | |||
| AlgoChanwiseSmall chanwise_small; | |||
| std::vector<AlgoGroupConvGeneral> gconv; | |||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
| std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||
| std::vector<AlgoBase*> | |||
| //! all algorithms | |||
| all_algos, | |||
| //! non-cudnn algos, used for heuristic if cudnn is not supported | |||
| non_cudnn_algos; | |||
| non_cudnn_algos, | |||
| bfloat16_algos; | |||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdDataAlgo_t algo); | |||
| }; | |||
| @@ -0,0 +1,115 @@ | |||
| /** | |||
| * \file src/cuda/convolution/backward_data/bfloat16.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./algo.h" | |||
| #include "src/cuda/convolution/chanwise/kern.cuh" | |||
| #include "src/cuda/utils.h" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace convolution; | |||
| ConvolutionBackwardDataImpl::AlgoBFloat16::AlgoBFloat16( | |||
| ConvolutionBackwardDataImpl::AlgoBase* algorithm) | |||
| : m_algorithm(algorithm) { | |||
| megdnn_assert_internal(algorithm); | |||
| m_name = ssprintf("CONVOLUTION_BACKWARD_DATD_BFLOAT16:%s", | |||
| m_algorithm->name()); | |||
| } | |||
//! Build a Float32 "shadow" SizeArgs for the wrapped algorithm: copy the
//! filter/diff/grad layouts, retype any BFloat16 layout to Float32, and
//! configure the scratch operator \p opr to run the wrapped impl in
//! DEFAULT compute mode. Retyped layouts are returned via the out-params
//! so callers can size conversion buffers.
ConvolutionBackwardDataImpl::AlgoBase::SizeArgs
ConvolutionBackwardDataImpl::AlgoBFloat16::float_args(
        const SizeArgs& args, ConvolutionBackwardDataImpl* opr,
        TensorLayout& ffilter, TensorLayout& fdiff, TensorLayout& fgrad) const {
    ffilter = *args.filter_layout;
    fdiff = *args.diff_layout;
    fgrad = *args.grad_layout;
    // Only BFloat16 layouts are rewritten; other dtypes pass through.
    auto change_dtype = [](TensorLayout& layout) {
        if (layout.dtype == dtype::BFloat16()) {
            layout.dtype = dtype::Float32();
        }
    };
    change_dtype(ffilter);
    change_dtype(fdiff);
    change_dtype(fgrad);
    // Mirror the caller's param, force Float32 computation, and pin the
    // execution policy to the wrapped algorithm.
    opr->param() = args.opr->param();
    opr->param().compute_mode = Param::ComputeMode::DEFAULT;
    opr->execution_policy() = {m_algorithm};
    return SizeArgs(opr, ffilter, fdiff, fgrad);
}
//! Usable only when diff and filter are both BFloat16 and the wrapped
//! algorithm can handle the Float32 version of the problem.
bool ConvolutionBackwardDataImpl::AlgoBFloat16::is_available(
        const SizeArgs& args) const {
    TensorLayout ffilter, fdiff, fgrad;
    // Scratch operator used only to form the Float32 SizeArgs.
    auto conv_back_data_opr =
            args.handle->create_operator<ConvolutionBackwardData>();
    SizeArgs fargs = float_args(
            args,
            static_cast<ConvolutionBackwardDataImpl*>(conv_back_data_opr.get()),
            ffilter, fdiff, fgrad);
    return args.diff_layout->dtype == args.filter_layout->dtype &&
           args.diff_layout->dtype == dtype::BFloat16() &&
           m_algorithm->is_available(fargs);
}
//! Workspace layout: one Float32 conversion buffer per tensor whose dtype
//! was rewritten by float_args(), followed by the wrapped algorithm's own
//! workspace as the last chunk. exec() relies on this chunk order.
WorkspaceBundle ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_bundle(
        void* ptr, const SizeArgs& args) const {
    TensorLayout ffilter, fdiff, fgrad;
    auto conv_back_data_opr =
            args.handle->create_operator<ConvolutionBackwardData>();
    SizeArgs fargs = float_args(
            args,
            static_cast<ConvolutionBackwardDataImpl*>(conv_back_data_opr.get()),
            ffilter, fdiff, fgrad);
    SmallVector<size_t> sizes;
    // A conversion buffer is needed only when the dtype actually changed.
    auto get_workspace = [&sizes](const TensorLayout& src,
                                  const TensorLayout& dst) {
        if (src.dtype != dst.dtype) {
            sizes.push_back(dst.span().dist_byte());
        }
    };
    get_workspace(*args.filter_layout, ffilter);
    get_workspace(*args.diff_layout, fdiff);
    get_workspace(*args.grad_layout, fgrad);
    // Last chunk: workspace of the wrapped Float32 algorithm.
    sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs));
    return {ptr, std::move(sizes)};
}
| size_t ConvolutionBackwardDataImpl::AlgoBFloat16::get_workspace_in_bytes( | |||
| const SizeArgs& args) const { | |||
| return get_workspace_bundle(nullptr, args).total_size_in_bytes(); | |||
| } | |||
//! Run the wrapped Float32 algorithm on BFloat16 tensors:
//! 1) convert filter/diff/grad to Float32 into workspace buffers,
//! 2) execute the wrapped backward-data in Float32,
//! 3) convert the Float32 gradient back into the BFloat16 grad tensor.
void ConvolutionBackwardDataImpl::AlgoBFloat16::exec(
        const ExecArgs& args) const {
    // Start as aliases of the originals; CompTypeCvter redirects each one
    // to a Float32 workspace buffer when its dtype needs conversion.
    TensorND ffilter_tensor = *args.filter_tensor;
    TensorND fdiff_tensor = *args.diff_tensor;
    TensorND fgrad_tensor = *args.grad_tensor;
    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
    CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle);
    {
        // Conversion order must match the buffer order in the bundle.
        cvter.src_to_comp_type(*args.filter_tensor, ffilter_tensor)
                .src_to_comp_type(*args.diff_tensor, fdiff_tensor)
                .src_to_comp_type(*args.grad_tensor, fgrad_tensor);
    }
    {
        auto conv_back_data_opr =
                args.handle->create_operator<ConvolutionBackwardData>();
        conv_back_data_opr->param() = args.opr->param();
        // Same configuration float_args() uses for sizing.
        conv_back_data_opr->param().compute_mode = Param::ComputeMode::DEFAULT;
        conv_back_data_opr->execution_policy() = {m_algorithm};
        // cvter.workspace() is the remaining (last) chunk of the bundle.
        conv_back_data_opr->exec(ffilter_tensor, fdiff_tensor, fgrad_tensor,
                                 cvter.workspace());
    }
    { cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); }
}
| // vim: syntax=cpp.doxygen | |||
| @@ -19,6 +19,10 @@ using namespace convolution; | |||
| bool ConvolutionBackwardDataImpl::AlgoChanwise::is_available( | |||
| const SizeArgs& args) const { | |||
| if (args.diff_layout->dtype == args.filter_layout->dtype && | |||
| args.diff_layout->dtype == dtype::BFloat16()) { | |||
| return false; | |||
| } | |||
| auto&& fm = args.filter_meta; | |||
| return args.filter_meta.format == Param::Format::NCHW && | |||
| args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | |||
| @@ -29,6 +29,10 @@ inline bool is_available_small(const chanwise::Param& param) { | |||
| bool ConvolutionBackwardDataImpl::AlgoChanwiseSmall::is_available( | |||
| const SizeArgs &args) const { | |||
| if (args.diff_layout->dtype == args.filter_layout->dtype && | |||
| args.diff_layout->dtype == dtype::BFloat16()) { | |||
| return false; | |||
| } | |||
| #if CUDA_VERSION < 9000 | |||
| if (args.diff_layout->dtype.enumv() == DTypeEnum::Float16) | |||
| return false; | |||
| @@ -38,6 +38,10 @@ ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral( | |||
| bool ConvolutionBackwardDataImpl::AlgoGroupConvGeneral::is_available( | |||
| const SizeArgs &args) const { | |||
| if (args.diff_layout->dtype == args.filter_layout->dtype && | |||
| args.diff_layout->dtype == dtype::BFloat16()) { | |||
| return false; | |||
| } | |||
| auto sub_args = args; | |||
| TensorLayout diff_pg, grad_pg; | |||
| modify_size_args(sub_args, diff_pg, grad_pg); | |||
| @@ -20,6 +20,10 @@ using namespace cuda; | |||
| bool ConvolutionBackwardDataImpl::AlgoMatmul::is_available( | |||
| const SizeArgs &args) const { | |||
| if (args.diff_layout->dtype == args.filter_layout->dtype && | |||
| args.diff_layout->dtype == dtype::BFloat16()) { | |||
| return false; | |||
| } | |||
| auto &&fm = args.filter_meta; | |||
| return args.filter_meta.format == Param::Format::NCHW && | |||
| args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | |||
| @@ -43,6 +43,12 @@ ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() { | |||
| megdnn_assert(all_algos_data == all_algos.data()); | |||
| non_cudnn_algos.push_back(all_algos.rbegin()[0]); // group matmul | |||
| size_t algo_size = all_algos.size(); | |||
| for (size_t i=0; i<algo_size; ++i) { | |||
| bfloat16_refhold.emplace_back(new AlgoBFloat16(all_algos[i])); | |||
| all_algos.push_back(bfloat16_refhold.back().get()); | |||
| bfloat16_algos.push_back(bfloat16_refhold.back().get()); | |||
| } | |||
| } | |||
| ConvolutionBackwardFilterImpl::AlgoCUDNN* | |||
| @@ -64,21 +70,20 @@ ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( | |||
| ConvolutionBackwardFilterImpl *o, | |||
| const TensorLayout &src, const TensorLayout &diff, | |||
| const TensorLayout &grad): | |||
| SizeArgs(o, src, diff, o->check_layout_fwd(src, grad, diff)) | |||
| SizeArgs(o, src, diff, grad, o->check_layout_fwd(src, grad, diff)) | |||
| { | |||
| } | |||
| ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( | |||
| ConvolutionBackwardFilterImpl *o, | |||
| const TensorLayout &src, const TensorLayout &diff, | |||
| const CanonizedFilterMeta &grad): | |||
| handle{concrete_handle(o->handle())}, | |||
| src_layout{&src}, | |||
| diff_layout{&diff}, | |||
| grad_filter_meta{grad}, | |||
| opr{o} | |||
| { | |||
| } | |||
| ConvolutionBackwardFilterImpl* o, const TensorLayout& src, | |||
| const TensorLayout& diff, const TensorLayout& grad, | |||
| const CanonizedFilterMeta& grad_meta) | |||
| : handle{concrete_handle(o->handle())}, | |||
| src_layout{&src}, | |||
| diff_layout{&diff}, | |||
| grad_layout{&grad}, | |||
| grad_filter_meta{grad_meta}, | |||
| opr{o} {} | |||
| ConvolutionBackwardFilterImpl::AlgoBase::ExecArgs::ExecArgs( | |||
| ConvolutionBackwardFilterImpl *opr, | |||
| @@ -30,7 +30,7 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm { | |||
| public: | |||
| struct SizeArgs { | |||
| HandleImpl *handle; | |||
| const TensorLayout *src_layout, *diff_layout; | |||
| const TensorLayout *src_layout, *diff_layout, *grad_layout; | |||
| CanonizedFilterMeta grad_filter_meta; | |||
| ConvolutionBackwardFilterImpl *opr; | |||
| @@ -42,12 +42,14 @@ class ConvolutionBackwardFilterImpl::AlgoBase: public Algorithm { | |||
| SizeArgs(ConvolutionBackwardFilterImpl *opr, | |||
| const TensorLayout &src, const TensorLayout &diff, | |||
| const TensorLayout &grad); | |||
| SizeArgs(ConvolutionBackwardFilterImpl *opr, | |||
| const TensorLayout &src, const TensorLayout &diff, | |||
| const CanonizedFilterMeta &grad); | |||
| SizeArgs(ConvolutionBackwardFilterImpl* opr, | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| const CanonizedFilterMeta& grad_meta); | |||
| convolution::ForwardSizeArgs as_fwd_args() const { | |||
| return {handle, src_layout, grad_filter_meta, diff_layout}; | |||
| return {handle, src_layout, grad_layout, grad_filter_meta, | |||
| diff_layout}; | |||
| } | |||
| }; | |||
| struct ExecArgs: public SizeArgs { | |||
| @@ -157,6 +159,25 @@ class ConvolutionBackwardFilterImpl::AlgoChanwise final: public AlgoBase { | |||
| } | |||
| }; | |||
//! Wrapper algorithm: serves BFloat16 backward-filter problems by
//! converting operands to Float32, running a wrapped Float32 algorithm,
//! and converting the gradient back.
class ConvolutionBackwardFilterImpl::AlgoBFloat16 final : public AlgoBase {
public:
    AlgoBFloat16(ConvolutionBackwardFilterImpl::AlgoBase*);
    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
    void exec(const ExecArgs& args) const override;
    const char* name() const override { return m_name.c_str(); }
    bool is_reproducible() const override { return true; }

private:
    std::string m_name;
    //! wrapped Float32 algorithm; owned by AlgoPack
    ConvolutionBackwardFilterImpl::AlgoBase* m_algorithm = nullptr;
    //! build Float32 SizeArgs (and the retyped layouts) for the wrapped impl
    SizeArgs float_args(const SizeArgs& args,
                        ConvolutionBackwardFilterImpl* opr, TensorLayout& fsrc,
                        TensorLayout& ffilter, TensorLayout& fdst) const;
    //! conversion buffers plus the wrapped impl's own workspace
    WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const;
};
| //! implement group conv by another algo | |||
| class ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral final: public AlgoBase { | |||
| AlgoBase *m_impl; | |||
| @@ -196,12 +217,14 @@ class ConvolutionBackwardFilterImpl::AlgoPack { | |||
| AlgoChanwise chanwise; | |||
| std::vector<AlgoGroupConvGeneral> gconv; | |||
| std::unordered_map<AlgoBase*, AlgoGroupConvGeneral*> algo2gconv; | |||
| std::vector<std::unique_ptr<AlgoBFloat16>> bfloat16_refhold; | |||
| std::vector<AlgoBase*> | |||
| //! all algorithms | |||
| all_algos, | |||
| //! non-cudnn algos, used for heuristic if cudnn is not supported | |||
| non_cudnn_algos; | |||
| non_cudnn_algos, | |||
| bfloat16_algos; | |||
| AlgoCUDNN* cudnn_from_enum(cudnnConvolutionBwdFilterAlgo_t algo); | |||
| }; | |||
| @@ -0,0 +1,117 @@ | |||
| /** | |||
| * \file src/cuda/convolution/backward_filter/bfloat16.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "./algo.h" | |||
| #include "src/cuda/convolution/chanwise/kern.cuh" | |||
| #include "src/cuda/utils.h" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| using namespace convolution; | |||
| ConvolutionBackwardFilterImpl::AlgoBFloat16::AlgoBFloat16( | |||
| ConvolutionBackwardFilterImpl::AlgoBase* algorithm) | |||
| : m_algorithm(algorithm) { | |||
| megdnn_assert_internal(algorithm); | |||
| m_name = ssprintf("CONVOLUTION_BACKWARD_Filter_BFLOAT16:%s", | |||
| m_algorithm->name()); | |||
| } | |||
| ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs | |||
| ConvolutionBackwardFilterImpl::AlgoBFloat16::float_args( | |||
| const SizeArgs& args, ConvolutionBackwardFilterImpl* opr, | |||
| TensorLayout& fsrc, TensorLayout& fdiff, TensorLayout& fgrad) const { | |||
| fsrc = *args.src_layout; | |||
| fdiff = *args.diff_layout; | |||
| fgrad = *args.grad_layout; | |||
| auto change_dtype = [](TensorLayout& layout) { | |||
| if (layout.dtype == dtype::BFloat16()) { | |||
| layout.dtype = dtype::Float32(); | |||
| } | |||
| }; | |||
| change_dtype(fsrc); | |||
| change_dtype(fdiff); | |||
| change_dtype(fgrad); | |||
| opr->param() = args.opr->param(); | |||
| opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
| opr->execution_policy() = {m_algorithm}; | |||
| return SizeArgs(opr, fsrc, fdiff, fgrad); | |||
| } | |||
| bool ConvolutionBackwardFilterImpl::AlgoBFloat16::is_available( | |||
| const SizeArgs& args) const { | |||
| TensorLayout fsrc, fdiff, fgrad; | |||
| auto conv_back_filter_opr = | |||
| args.handle->create_operator<ConvolutionBackwardFilter>(); | |||
| SizeArgs fargs = float_args(args, | |||
| static_cast<ConvolutionBackwardFilterImpl*>( | |||
| conv_back_filter_opr.get()), | |||
| fsrc, fdiff, fgrad); | |||
| return args.src_layout->dtype == args.diff_layout->dtype && | |||
| args.src_layout->dtype == dtype::BFloat16() && | |||
| m_algorithm->is_available(fargs); | |||
| } | |||
| WorkspaceBundle | |||
| ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_bundle( | |||
| void* ptr, const SizeArgs& args) const { | |||
| TensorLayout fsrc, fdiff, fgrad; | |||
| auto conv_back_filter_opr = | |||
| args.handle->create_operator<ConvolutionBackwardFilter>(); | |||
| SizeArgs fargs = float_args(args, | |||
| static_cast<ConvolutionBackwardFilterImpl*>( | |||
| conv_back_filter_opr.get()), | |||
| fsrc, fdiff, fgrad); | |||
| SmallVector<size_t> sizes; | |||
| auto get_workspace = [&sizes](const TensorLayout& src, | |||
| const TensorLayout& dst) { | |||
| if (src.dtype != dst.dtype) { | |||
| sizes.push_back(dst.span().dist_byte()); | |||
| } | |||
| }; | |||
| get_workspace(*args.src_layout, fsrc); | |||
| get_workspace(*args.diff_layout, fdiff); | |||
| get_workspace(*args.grad_layout, fgrad); | |||
| sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs)); | |||
| return {ptr, std::move(sizes)}; | |||
| } | |||
| size_t ConvolutionBackwardFilterImpl::AlgoBFloat16::get_workspace_in_bytes( | |||
| const SizeArgs& args) const { | |||
| return get_workspace_bundle(nullptr, args).total_size_in_bytes(); | |||
| } | |||
| void ConvolutionBackwardFilterImpl::AlgoBFloat16::exec( | |||
| const ExecArgs& args) const { | |||
| TensorND fsrc_tensor = *args.src_tensor; | |||
| TensorND fdiff_tensor = *args.diff_tensor; | |||
| TensorND fgrad_tensor = *args.grad_tensor; | |||
| auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | |||
| CompTypeCvter<dtype::BFloat16, dtype::Float32> cvter(args.handle, &bundle); | |||
| { | |||
| cvter.src_to_comp_type(*args.src_tensor, fsrc_tensor) | |||
| .src_to_comp_type(*args.diff_tensor, fdiff_tensor) | |||
| .src_to_comp_type(*args.grad_tensor, fgrad_tensor); | |||
| } | |||
| { | |||
| auto conv_back_filter_opr = | |||
| args.handle->create_operator<ConvolutionBackwardFilter>(); | |||
| conv_back_filter_opr->param() = args.opr->param(); | |||
| conv_back_filter_opr->param().compute_mode = | |||
| Param::ComputeMode::DEFAULT; | |||
| conv_back_filter_opr->execution_policy() = {m_algorithm}; | |||
| conv_back_filter_opr->exec(fsrc_tensor, fdiff_tensor, fgrad_tensor, | |||
| cvter.workspace()); | |||
| } | |||
| { cvter.comp_to_dst_type(fgrad_tensor, *args.grad_tensor); } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -19,6 +19,10 @@ using namespace convolution; | |||
| bool ConvolutionBackwardFilterImpl::AlgoChanwise::is_available( | |||
| const SizeArgs &args) const { | |||
| if (args.src_layout->dtype == args.src_layout->dtype && | |||
| args.diff_layout->dtype == dtype::BFloat16()) { | |||
| return false; | |||
| } | |||
| auto &&fm = args.grad_filter_meta; | |||
| return fm.format == Param::Format::NCHW && | |||
| args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | |||
| @@ -38,6 +38,10 @@ ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::AlgoGroupConvGeneral( | |||
| bool ConvolutionBackwardFilterImpl::AlgoGroupConvGeneral::is_available( | |||
| const SizeArgs &args) const { | |||
| if (args.src_layout->dtype == args.src_layout->dtype && | |||
| args.diff_layout->dtype == dtype::BFloat16()) { | |||
| return false; | |||
| } | |||
| auto sub_args = args; | |||
| TensorLayout src_pg, diff_pg; | |||
| modify_size_args(sub_args, src_pg, diff_pg); | |||
| @@ -19,6 +19,10 @@ using namespace cuda; | |||
| bool ConvolutionBackwardFilterImpl::AlgoMatmul::is_available( | |||
| const SizeArgs &args) const { | |||
| if (args.src_layout->dtype == args.src_layout->dtype && | |||
| args.diff_layout->dtype == dtype::BFloat16()) { | |||
| return false; | |||
| } | |||
| auto &&fm = args.grad_filter_meta; | |||
| return fm.format == Param::Format::NCHW && | |||
| args.diff_layout->dtype.category() == DTypeCategory::FLOAT && | |||
| @@ -16,6 +16,10 @@ using namespace cuda; | |||
| using namespace convolution; | |||
| bool convolution::is_cudnn_supported(const ForwardSizeArgs &args) { | |||
| if (args.src_layout->dtype == args.filter_layout->dtype && | |||
| args.src_layout->dtype == dtype::BFloat16()) { | |||
| return false; | |||
| } | |||
| // CUDNN_STATUS_EXECUTION_FAILED on Tegra K1, so disable CUDNN | |||
| // on Tegra K1. | |||
| @@ -25,6 +25,7 @@ namespace convolution { | |||
//! size-only arguments describing a forward convolution problem
//! (layouts only, no tensor data); aggregate — member order is part of
//! the initialization contract, do not reorder.
struct ForwardSizeArgs {
    HandleImpl *handle;                   //!< execution handle (non-owning)
    const TensorLayout *src_layout;       //!< input layout (non-owning)
    const TensorLayout *filter_layout;    //!< raw filter layout (non-owning)
    CanonizedFilterMeta filter_meta;      //!< canonized filter description
    const TensorLayout *dst_layout;       //!< output layout (non-owning)
};
| @@ -102,7 +102,8 @@ void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) { | |||
| AlgoBase::ExecArgs args(this, filter, diff, grad, workspace); | |||
| auto algo = get_algorithm(this, args.filter_meta, diff.layout, grad.layout); | |||
| auto algo = get_algorithm(this, filter.layout, args.filter_meta, | |||
| diff.layout, grad.layout); | |||
| algo->check_workspace(args, workspace).exec(args); | |||
| } | |||
| @@ -120,16 +121,16 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
| bool reproducible) { | |||
| auto fm = check_layout_fwd(grad, filter, diff); | |||
| return get_algorithm_heuristic(fm, diff, grad, workspace_limit_in_bytes, | |||
| reproducible); | |||
| return get_algorithm_heuristic(filter, fm, diff, grad, | |||
| workspace_limit_in_bytes, reproducible); | |||
| } | |||
| ConvolutionBackwardDataImpl::Algorithm* | |||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||
| const CanonizedFilterMeta& filter, const TensorLayout& diff, | |||
| ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter, | |||
| const CanonizedFilterMeta& filter_meta, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
| bool reproducible) { | |||
| AlgoBase::SizeArgs args(this, filter, diff, grad); | |||
| AlgoBase::SizeArgs args(this, filter, filter_meta, diff, grad); | |||
| if (args.filter_meta.group > 1 && | |||
| sm_algo_pack.chanwise.is_available_reproducible( | |||
| @@ -209,14 +210,27 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic( | |||
| args = orig_args; | |||
| } | |||
| if (reproducible) { | |||
| return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( | |||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||
| "cuda conv bwd_data"); | |||
| if (args.filter_layout->dtype.enumv() != | |||
| DTypeTrait<dtype::BFloat16>::enumv) { | |||
| if (reproducible) { | |||
| return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( | |||
| sm_algo_pack.non_cudnn_algos, args, | |||
| workspace_limit_in_bytes, "cuda conv bwd_data"); | |||
| } else { | |||
| return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( | |||
| sm_algo_pack.non_cudnn_algos, args, | |||
| workspace_limit_in_bytes, "cuda conv bwd_data"); | |||
| } | |||
| } else { | |||
| return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( | |||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||
| "cuda conv bwd_data"); | |||
| if (reproducible) { | |||
| return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( | |||
| sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||
| "cuda conv bwd_data"); | |||
| } else { | |||
| return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( | |||
| sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||
| "cuda conv bwd_data"); | |||
| } | |||
| } | |||
| } | |||
| @@ -225,7 +239,7 @@ size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad) { | |||
| AlgoBase::SizeArgs args(this, filter, diff, grad); | |||
| return get_algorithm(this, args.filter_meta, diff, grad)-> | |||
| return get_algorithm(this, filter, args.filter_meta, diff, grad)-> | |||
| get_workspace_in_bytes(args); | |||
| } | |||
| @@ -241,7 +255,7 @@ void ConvolutionBackwardFilterImpl::exec(_megdnn_tensor_in src, | |||
| _megdnn_workspace workspace) { | |||
| AlgoBase::ExecArgs args(this, src, diff, grad, workspace); | |||
| auto algo = get_algorithm(this, src.layout, diff.layout, | |||
| args.grad_filter_meta); | |||
| grad.layout, args.grad_filter_meta); | |||
| algo->check_workspace(args, workspace).exec(args); | |||
| } | |||
| @@ -259,16 +273,16 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||
| const TensorLayout& grad, size_t workspace_limit_in_bytes, | |||
| bool reproducible) { | |||
| auto fm = check_layout_fwd(src, grad, diff); | |||
| return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes, | |||
| reproducible); | |||
| return get_algorithm_heuristic(src, diff, grad, fm, | |||
| workspace_limit_in_bytes, reproducible); | |||
| } | |||
| ConvolutionBackwardFilterImpl::Algorithm* | |||
| ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||
| const TensorLayout& src, const TensorLayout& diff, | |||
| const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, | |||
| bool reproducible) { | |||
| AlgoBase::SizeArgs args(this, src, diff, grad); | |||
| const TensorLayout& grad, const CanonizedFilterMeta& grad_meta, | |||
| size_t workspace_limit_in_bytes, bool reproducible) { | |||
| AlgoBase::SizeArgs args(this, src, diff, grad, grad_meta); | |||
| if (args.grad_filter_meta.group > 1 && | |||
| sm_algo_pack.chanwise.is_available_reproducible( | |||
| @@ -349,14 +363,26 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||
| args = orig_args; | |||
| } | |||
| if (reproducible) { | |||
| return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( | |||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||
| "cuda conv bwd_filter"); | |||
| if (args.src_layout->dtype.enumv() != DTypeTrait<dtype::BFloat16>::enumv) { | |||
| if (reproducible) { | |||
| return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( | |||
| sm_algo_pack.non_cudnn_algos, args, | |||
| workspace_limit_in_bytes, "cuda conv bwd_filter"); | |||
| } else { | |||
| return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( | |||
| sm_algo_pack.non_cudnn_algos, args, | |||
| workspace_limit_in_bytes, "cuda conv bwd_filter"); | |||
| } | |||
| } else { | |||
| return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( | |||
| sm_algo_pack.non_cudnn_algos, args, workspace_limit_in_bytes, | |||
| "cuda conv bwd_filter"); | |||
| if (reproducible) { | |||
| return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( | |||
| sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||
| "cuda conv bwd_filter"); | |||
| } else { | |||
| return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( | |||
| sm_algo_pack.bfloat16_algos, args, workspace_limit_in_bytes, | |||
| "cuda conv bwd_filter"); | |||
| } | |||
| } | |||
| } | |||
| @@ -365,7 +391,7 @@ size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes( | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad) { | |||
| AlgoBase::SizeArgs args(this, src, diff, grad); | |||
| return get_algorithm(this, src, diff, args.grad_filter_meta)-> | |||
| return get_algorithm(this, src, diff, grad, args.grad_filter_meta)-> | |||
| get_workspace_in_bytes(args); | |||
| } | |||
| @@ -60,11 +60,11 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible) override; | |||
| Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible); | |||
| Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& filter, | |||
| const CanonizedFilterMeta& filter_meta, | |||
| const TensorLayout& diff, const TensorLayout& grad, | |||
| size_t workspace_limit_in_bytes, bool reproducible); | |||
| size_t get_workspace_in_bytes(const TensorLayout& filter, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) override; | |||
| @@ -76,6 +76,7 @@ class ConvolutionBackwardDataImpl: public ConvolutionBackwardData { | |||
| class AlgoChanwise; | |||
| class AlgoChanwiseSmall; | |||
| class AlgoGroupConvGeneral; | |||
| class AlgoBFloat16; | |||
| class AlgoPack; | |||
| @@ -104,7 +105,8 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||
| bool reproducible) override; | |||
| Algorithm* get_algorithm_heuristic(const TensorLayout& src, | |||
| const TensorLayout& diff, | |||
| const CanonizedFilterMeta& grad, | |||
| const TensorLayout& gradk, | |||
| const CanonizedFilterMeta& grad_meta, | |||
| size_t workspace_limit_in_bytes, | |||
| bool reproducible); | |||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| @@ -117,6 +119,7 @@ class ConvolutionBackwardFilterImpl: public ConvolutionBackwardFilter { | |||
| class AlgoMatmul; | |||
| class AlgoChanwise; | |||
| class AlgoGroupConvGeneral; | |||
| class AlgoBFloat16; | |||
| class AlgoPack; | |||
| @@ -50,7 +50,7 @@ class Convolution3DForwardImpl::AlgoBase: public Algorithm { | |||
| const CanonizedFilterMeta &filter, | |||
| const TensorLayout &dst); | |||
| }; | |||
| struct ExecArgs: public SizeArgs { | |||
| struct ExecArgs : public SizeArgs { | |||
| const TensorND *src_tensor, *filter_tensor, *dst_tensor; | |||
| Workspace workspace; | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/ABS_GRAD_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/ABS_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/ACOS_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOS, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/ADD_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/ASIN_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASIN, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/ATAN2_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/CEIL_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/COND_LEQ_MOV_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) | |||
| #define KERN_IMPL_ARITY 3 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/COS_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COS, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/EQ_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/ERFCINV_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/ERFC_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/ERFINV_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFINV, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/ERF_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERF, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/EXPM1_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXPM1, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/EXP_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXP, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/FAST_TANH_GRAD_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH_GRAD, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/FAST_TANH_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/FLOOR_DIV_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/FLOOR_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_RELU_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/FUSE_ADD_TANH_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/FUSE_MUL_ADD3_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) | |||
| #define KERN_IMPL_ARITY 3 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/H_SWISH_GRAD_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/H_SWISH_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/LEQ_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/LOG1P_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG1P, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/LOG_SUM_EXP_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG_SUM_EXP, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/LOG_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/LT_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/MAX_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/MIN_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/MOD_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/MUL_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/NEGATE_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/POW_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(POW, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/RELU_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/ROUND_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/SIGMOID_GRAD_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/SIGMOID_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/SIN_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIN, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/SUB_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/SWITCH_GT0_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/TANH_GRAD_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/TANH_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) | |||
| #define KERN_IMPL_ARITY 1 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,17 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/kimpl/TRUE_DIV_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TRUE_DIV, cb) | |||
| #define KERN_IMPL_ARITY 2 | |||
| #define KERN_IMPL_CTYPE dt_bfloat16 | |||
| #include "../kern_impl.inl" | |||
| #endif | |||
| @@ -0,0 +1,18 @@ | |||
| /** | |||
| * \file dnn/src/cuda/elemwise/special_kimpl/special_dt_bfloat16.cu | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| // generated by gen_elemwise_special_kern_impls.py | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| #include "../special_kerns.inl" | |||
| INST(::megdnn::dtype::BFloat16) | |||
| #undef INST | |||
| } | |||
| } | |||
| #endif | |||
| @@ -141,6 +141,9 @@ INST_FOR_CTYPE | |||
| #define ct dt_float16 | |||
| INST_FOR_CTYPE | |||
| #undef ct | |||
| #define ct dt_bfloat16 | |||
| INST_FOR_CTYPE | |||
| #undef ct | |||
| #define ct dt_int8 | |||
| INST_FOR_CTYPE | |||
| #undef ct | |||
| @@ -68,6 +68,17 @@ namespace elemwise_intl { | |||
| return t; | |||
| } | |||
| struct __attribute__((aligned(8))) bhalf4 { | |||
| dt_bfloat16 x, y, z, w; | |||
| }; | |||
| __device__ __forceinline__ bhalf4 make_bhalf4(dt_bfloat16 x, dt_bfloat16 y, | |||
| dt_bfloat16 z, dt_bfloat16 w) { | |||
| bhalf4 t; | |||
| t.x = x, t.y = y, t.z = z, t.w = w; | |||
| return t; | |||
| } | |||
| #define INST(_ctype, _vect_type) \ | |||
| template <> \ | |||
| class VectTypeTrait<_ctype> { \ | |||
| @@ -87,6 +98,7 @@ namespace elemwise_intl { | |||
| INST(dt_uint8, uchar4); | |||
| INST(dt_float32, float4); | |||
| INST(dt_float16, half4); | |||
| INST(dt_bfloat16, bhalf4); | |||
| INST(dt_int32, int4); | |||
| INST(dt_int16, short4); | |||
| #undef as_raw | |||
| @@ -17,6 +17,11 @@ __device__ void atomicAdd(megdnn::dt_float16 *, megdnn::dt_float16) { | |||
| __trap(); | |||
| ((int*)0)[0] = 1; | |||
| } | |||
| __device__ void atomicAdd(megdnn::dt_bfloat16 *, megdnn::dt_bfloat16) { | |||
| __trap(); | |||
| ((int*)0)[0] = 1; | |||
| } | |||
| #endif | |||
| __device__ void atomicAdd(megdnn::dt_int8 *, megdnn::dt_int8) { | |||
| @@ -29,6 +29,10 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() { | |||
| all_algos.push_back(&cublas_lt); | |||
| #endif | |||
| all_algos.push_back(&naive); | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| cublas_bfloat16 = std::make_unique<AlgoBFloat16>(&cublas); | |||
| all_algos.push_back(cublas_bfloat16.get()); | |||
| #endif | |||
| } | |||
| MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack; | |||
| @@ -15,6 +15,7 @@ | |||
| #include "src/cuda/matrix_mul/opr_impl.h" | |||
| #include <cuda.h> | |||
| #include <memory> | |||
| #if CUDA_VERSION >= 10010 | |||
| #include <cublasLt.h> | |||
| #endif | |||
| @@ -140,6 +141,24 @@ public: | |||
| bool is_reproducible() const override { return true; } | |||
| }; | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| class MatrixMulForwardImpl::AlgoBFloat16 final : public AlgoBase { | |||
| public: | |||
| AlgoBFloat16(MatrixMulForwardImpl::AlgoBase*); | |||
| bool is_available(const SizeArgs& args) const override; | |||
| size_t get_workspace_in_bytes(const SizeArgs& args) const override; | |||
| const char* name() const override { return m_name.c_str(); } | |||
| void exec(const ExecArgs& args) const override; | |||
| bool is_reproducible() const override { return true; } | |||
| private: | |||
| MatrixMulForwardImpl::AlgoBase* m_algorithm = nullptr; | |||
| std::string m_name; | |||
| WorkspaceBundle get_workspace_bundle(void* ptr, const SizeArgs& args) const; | |||
| SizeArgs float_args(const SizeArgs& args) const; | |||
| }; | |||
| #endif | |||
| class MatrixMulForwardImpl::AlgoPack { | |||
| AlgoPack(const AlgoPack&) = delete; | |||
| AlgoPack& operator=(const AlgoPack&) = delete; | |||
| @@ -154,7 +173,9 @@ public: | |||
| #if CUDA_VERSION >= 10010 | |||
| AlgoCuBlasLt cublas_lt; | |||
| #endif | |||
| #if !MEGDNN_DISABLE_FLOAT16 | |||
| std::unique_ptr<AlgoBFloat16> cublas_bfloat16; | |||
| #endif | |||
| std::vector<AlgoBase*> all_algos; | |||
| }; | |||
| @@ -0,0 +1,91 @@ | |||
| /** | |||
| * \file dnn/src/cuda/matrix_mul/bfloat16.cpp | |||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
| * | |||
| * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, | |||
| * software distributed under the License is distributed on an | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "src/cuda/handle.h" | |||
| #include "src/cuda/matrix_mul/algos.h" | |||
| #include "src/cuda/utils.h" | |||
| using namespace megdnn; | |||
| using namespace cuda; | |||
| MatrixMulForwardImpl::AlgoBFloat16::AlgoBFloat16( | |||
| MatrixMulForwardImpl::AlgoBase* algorithm) | |||
| : m_algorithm(algorithm) { | |||
| megdnn_assert_internal(algorithm); | |||
| m_name = ssprintf("MATMUL_BFLOAT16:%s", m_algorithm->name()); | |||
| } | |||
| MatrixMulForwardImpl::AlgoBase::SizeArgs | |||
| MatrixMulForwardImpl::AlgoBFloat16::float_args(const SizeArgs& args) const { | |||
| auto new_args = args; | |||
| auto change_dtype = [](TensorLayout& layout) { | |||
| if (layout.dtype == dtype::BFloat16()) { | |||
| layout.dtype = dtype::Float32(); | |||
| } | |||
| }; | |||
| change_dtype(new_args.layout_a); | |||
| change_dtype(new_args.layout_b); | |||
| change_dtype(new_args.layout_c); | |||
| return new_args; | |||
| } | |||
| bool MatrixMulForwardImpl::AlgoBFloat16::is_available( | |||
| const SizeArgs& args) const { | |||
| auto fargs = float_args(args); | |||
| return args.layout_a.dtype == dtype::BFloat16() && | |||
| m_algorithm->is_available(fargs); | |||
| } | |||
| WorkspaceBundle MatrixMulForwardImpl::AlgoBFloat16::get_workspace_bundle( | |||
| void* ptr, const SizeArgs& args) const { | |||
| auto fargs = float_args(args); | |||
| SmallVector<size_t> sizes; | |||
| auto get_workspace = [&sizes](const TensorLayout& src) { | |||
| TensorLayout dst = src; | |||
| if (dst.dtype == dtype::BFloat16()) { | |||
| dst.dtype = dtype::Float32(); | |||
| sizes.push_back(dst.span().dist_byte()); | |||
| } | |||
| }; | |||
| get_workspace(args.layout_a); | |||
| get_workspace(args.layout_b); | |||
| get_workspace(args.layout_c); | |||
| sizes.push_back(m_algorithm->get_workspace_in_bytes(fargs)); | |||
| return {ptr, std::move(sizes)}; | |||
| } | |||
| size_t MatrixMulForwardImpl::AlgoBFloat16::get_workspace_in_bytes( | |||
| const SizeArgs& args) const { | |||
| return get_workspace_bundle(nullptr, args).total_size_in_bytes(); | |||
| } | |||
| void MatrixMulForwardImpl::AlgoBFloat16::exec(const ExecArgs& args) const { | |||
| TensorND a = args.tensor_a; | |||
| TensorND b = args.tensor_b; | |||
| TensorND c = args.tensor_c; | |||
| auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args); | |||
| auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>( | |||
| args.opr->handle(), &bundle); | |||
| ctypecvt.src_to_comp_type(args.tensor_a, a) | |||
| .src_to_comp_type(args.tensor_b, b) | |||
| .src_to_comp_type(args.tensor_c, c); | |||
| { | |||
| auto matmul_opr = | |||
| args.opr->handle()->create_operator<MatrixMulForward>(); | |||
| matmul_opr->param() = args.opr->param(); | |||
| matmul_opr->param().compute_mode = Param::ComputeMode::DEFAULT; | |||
| matmul_opr->execution_policy() = {m_algorithm}; | |||
| matmul_opr->exec(a, b, c, ctypecvt.workspace()); | |||
| } | |||
| ctypecvt.comp_to_dst_type(c, args.tensor_c); | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||