| @@ -23,9 +23,9 @@ | |||||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | #pragma GCC diagnostic ignored "-Wunused-parameter" | ||||
| #pragma GCC diagnostic ignored "-Wdeprecated-declarations" | #pragma GCC diagnostic ignored "-Wdeprecated-declarations" | ||||
| #pragma GCC diagnostic ignored "-Wsign-compare" | #pragma GCC diagnostic ignored "-Wsign-compare" | ||||
| #include <hip/hip_runtime_api.h> | |||||
| #include <hip/hip_runtime.h> | |||||
| #include <hip/hip_fp16.h> | #include <hip/hip_fp16.h> | ||||
| #include <hip/hip_runtime.h> | |||||
| #include <hip/hip_runtime_api.h> | |||||
| #pragma GCC diagnostic pop | #pragma GCC diagnostic pop | ||||
| #if !defined(__HIP_PLATFORM_HCC__) | #if !defined(__HIP_PLATFORM_HCC__) | ||||
| @@ -11,10 +11,10 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "megdnn/thin/function.h" | |||||
| #include "megcore_cdefs.h" | |||||
| #include <cstddef> | #include <cstddef> | ||||
| #include <memory> | #include <memory> | ||||
| #include "megcore_cdefs.h" | |||||
| #include "megdnn/thin/function.h" | |||||
| #include "megdnn/internal/visibility_prologue.h" | #include "megdnn/internal/visibility_prologue.h" | ||||
| @@ -26,36 +26,35 @@ namespace megcore { | |||||
| * the caller thread immediately. | * the caller thread immediately. | ||||
| */ | */ | ||||
| class CPUDispatcher { | class CPUDispatcher { | ||||
| public: | |||||
| using Task = megdnn::thin_function<void()>; | |||||
| using MultiThreadingTask = megdnn::thin_function<void(size_t, size_t)>; | |||||
| virtual ~CPUDispatcher() noexcept; | |||||
| /*! | |||||
| * \brief dispatch a task on the computing thread | |||||
| * \param task the task that would be moved away | |||||
| */ | |||||
| virtual void dispatch(Task&& task) = 0; | |||||
| /*! | |||||
| * \brief dispatch a multithreading task on the computing thread | |||||
| * \param task the task would be moved away | |||||
| * \param parallelism the parallelism of the task. | |||||
| */ | |||||
| virtual void dispatch(MultiThreadingTask&& task, | |||||
| size_t parallelism) = 0; | |||||
| /*! | |||||
| * \brief synchronize the calling thread with the computing thread | |||||
| */ | |||||
| virtual void sync() = 0; | |||||
| /*! | |||||
| * \brief the computing thread number. | |||||
| */ | |||||
| virtual size_t nr_threads() = 0; | |||||
| public: | |||||
| using Task = megdnn::thin_function<void()>; | |||||
| using MultiThreadingTask = megdnn::thin_function<void(size_t, size_t)>; | |||||
| virtual ~CPUDispatcher() noexcept; | |||||
| /*! | |||||
| * \brief dispatch a task on the computing thread | |||||
| * \param task the task that would be moved away | |||||
| */ | |||||
| virtual void dispatch(Task&& task) = 0; | |||||
| /*! | |||||
| * \brief dispatch a multithreading task on the computing thread | |||||
| * \param task the task would be moved away | |||||
| * \param parallelism the parallelism of the task. | |||||
| */ | |||||
| virtual void dispatch(MultiThreadingTask&& task, size_t parallelism) = 0; | |||||
| /*! | |||||
| * \brief synchronize the calling thread with the computing thread | |||||
| */ | |||||
| virtual void sync() = 0; | |||||
| /*! | |||||
| * \brief the computing thread number. | |||||
| */ | |||||
| virtual size_t nr_threads() = 0; | |||||
| }; | }; | ||||
| } // namespace megcore | |||||
| } // namespace megcore | |||||
| using MegcoreCPUDispatcher = megcore::CPUDispatcher; | using MegcoreCPUDispatcher = megcore::CPUDispatcher; | ||||
| @@ -63,75 +62,62 @@ using MegcoreCPUDispatcher = megcore::CPUDispatcher; | |||||
| * \brief Layer 1: device handle | * \brief Layer 1: device handle | ||||
| */ | */ | ||||
| struct megcoreDeviceContext; | struct megcoreDeviceContext; | ||||
| typedef struct megcoreDeviceContext *megcoreDeviceHandle_t; | |||||
| typedef struct megcoreDeviceContext* megcoreDeviceHandle_t; | |||||
| megcoreStatus_t megcoreCreateDeviceHandle( | megcoreStatus_t megcoreCreateDeviceHandle( | ||||
| megcoreDeviceHandle_t *handle, | |||||
| megcorePlatform_t platform, | |||||
| int deviceID = -1, | |||||
| megcoreDeviceHandle_t* handle, megcorePlatform_t platform, int deviceID = -1, | |||||
| unsigned int flags = 0); | unsigned int flags = 0); | ||||
| megcoreStatus_t megcoreDestroyDeviceHandle( | |||||
| megcoreDeviceHandle_t handle); | |||||
| megcoreStatus_t megcoreGetPlatform(megcoreDeviceHandle_t handle, | |||||
| megcorePlatform_t *platform); | |||||
| megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle, | |||||
| int *deviceID); | |||||
| megcoreStatus_t megcoreGetMemAlignment(megcoreDeviceHandle_t handle, | |||||
| size_t *memAlignmentInBytes); | |||||
| megcoreStatus_t megcoreDestroyDeviceHandle(megcoreDeviceHandle_t handle); | |||||
| megcoreStatus_t megcoreGetPlatform( | |||||
| megcoreDeviceHandle_t handle, megcorePlatform_t* platform); | |||||
| megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle, int* deviceID); | |||||
| megcoreStatus_t megcoreGetMemAlignment( | |||||
| megcoreDeviceHandle_t handle, size_t* memAlignmentInBytes); | |||||
| megcoreStatus_t megcoreGetDeviceFlags( | megcoreStatus_t megcoreGetDeviceFlags( | ||||
| megcoreDeviceHandle_t handle, | |||||
| unsigned int *flags); | |||||
| megcoreDeviceHandle_t handle, unsigned int* flags); | |||||
| megcoreStatus_t megcoreActivate(megcoreDeviceHandle_t handle); | megcoreStatus_t megcoreActivate(megcoreDeviceHandle_t handle); | ||||
| megcoreStatus_t megcoreDeactivate(megcoreDeviceHandle_t handle); | megcoreStatus_t megcoreDeactivate(megcoreDeviceHandle_t handle); | ||||
| megcoreStatus_t megcoreMalloc(megcoreDeviceHandle_t handle, | |||||
| void **devPtr, size_t sizeInBytes); | |||||
| megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle, | |||||
| void *devPtr); | |||||
| megcoreStatus_t megcoreMalloc( | |||||
| megcoreDeviceHandle_t handle, void** devPtr, size_t sizeInBytes); | |||||
| megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle, void* devPtr); | |||||
| /** | /** | ||||
| * \brief Layer 2: computing handle | * \brief Layer 2: computing handle | ||||
| */ | */ | ||||
| struct megcoreComputingContext; | struct megcoreComputingContext; | ||||
| typedef struct megcoreComputingContext *megcoreComputingHandle_t; | |||||
| typedef struct megcoreComputingContext* megcoreComputingHandle_t; | |||||
| megcoreStatus_t megcoreCreateComputingHandle( | megcoreStatus_t megcoreCreateComputingHandle( | ||||
| megcoreComputingHandle_t *compHandle, | |||||
| megcoreDeviceHandle_t devHandle, | |||||
| megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | |||||
| unsigned int flags = 0); | unsigned int flags = 0); | ||||
| megcoreStatus_t megcoreCreateComputingHandleWithCPUDispatcher( | megcoreStatus_t megcoreCreateComputingHandleWithCPUDispatcher( | ||||
| megcoreComputingHandle_t *compHandle, | |||||
| megcoreDeviceHandle_t devHandle, | |||||
| megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | |||||
| const std::shared_ptr<MegcoreCPUDispatcher>& dispatcher, | const std::shared_ptr<MegcoreCPUDispatcher>& dispatcher, | ||||
| unsigned int flags = 0); | unsigned int flags = 0); | ||||
| megcoreStatus_t megcoreDestroyComputingHandle( | |||||
| megcoreComputingHandle_t handle); | |||||
| megcoreStatus_t megcoreDestroyComputingHandle(megcoreComputingHandle_t handle); | |||||
| megcoreStatus_t megcoreGetDeviceHandle( | megcoreStatus_t megcoreGetDeviceHandle( | ||||
| megcoreComputingHandle_t compHandle, | |||||
| megcoreDeviceHandle_t *devHandle); | |||||
| megcoreComputingHandle_t compHandle, megcoreDeviceHandle_t* devHandle); | |||||
| megcoreStatus_t megcoreGetComputingFlags( | megcoreStatus_t megcoreGetComputingFlags( | ||||
| megcoreComputingHandle_t handle, | |||||
| unsigned int *flags); | |||||
| megcoreComputingHandle_t handle, unsigned int* flags); | |||||
| MegcoreCPUDispatcher* megcoreGetCPUDispatcher(megcoreComputingHandle_t handle); | MegcoreCPUDispatcher* megcoreGetCPUDispatcher(megcoreComputingHandle_t handle); | ||||
| megcoreStatus_t megcoreMemcpy( | megcoreStatus_t megcoreMemcpy( | ||||
| megcoreComputingHandle_t handle, | |||||
| void *dst, const void *src, size_t sizeInBytes, | |||||
| megcoreComputingHandle_t handle, void* dst, const void* src, size_t sizeInBytes, | |||||
| megcoreMemcpyKind_t kind); | megcoreMemcpyKind_t kind); | ||||
| megcoreStatus_t megcoreMemset( | megcoreStatus_t megcoreMemset( | ||||
| megcoreComputingHandle_t handle, | |||||
| void *dst, int value, size_t sizeInBytes); | |||||
| megcoreComputingHandle_t handle, void* dst, int value, size_t sizeInBytes); | |||||
| megcoreStatus_t megcoreSynchronize(megcoreComputingHandle_t handle); | megcoreStatus_t megcoreSynchronize(megcoreComputingHandle_t handle); | ||||
| /** | /** | ||||
| * \brief Miscellaneous | * \brief Miscellaneous | ||||
| */ | */ | ||||
| const char *megcoreGetErrorName(megcoreStatus_t status); | |||||
| const char* megcoreGetErrorName(megcoreStatus_t status); | |||||
| #include "megdnn/internal/visibility_epilogue.h" | #include "megdnn/internal/visibility_epilogue.h" | ||||
| @@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithAtlasContext( | |||||
| megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | ||||
| unsigned int flags, const AtlasContext& ctx); | unsigned int flags, const AtlasContext& ctx); | ||||
| megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle, | |||||
| AtlasContext* ctx); | |||||
| megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle, AtlasContext* ctx); | |||||
| namespace atlas { | namespace atlas { | ||||
| //! convert acl error code to error string | //! convert acl error code to error string | ||||
| @@ -47,12 +46,12 @@ inline megcoreStatus_t megcoreCreateComputingHandleWithACLStream( | |||||
| megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | ||||
| unsigned int flags, aclrtStream stream) { | unsigned int flags, aclrtStream stream) { | ||||
| megcore::AtlasContext ctx{stream}; | megcore::AtlasContext ctx{stream}; | ||||
| return megcore::createComputingHandleWithAtlasContext(compHandle, devHandle, | |||||
| flags, ctx); | |||||
| return megcore::createComputingHandleWithAtlasContext( | |||||
| compHandle, devHandle, flags, ctx); | |||||
| } | } | ||||
| inline megcoreStatus_t megcoreGetACLStream(megcoreComputingHandle_t handle, | |||||
| aclrtStream* stream) { | |||||
| inline megcoreStatus_t megcoreGetACLStream( | |||||
| megcoreComputingHandle_t handle, aclrtStream* stream) { | |||||
| megcore::AtlasContext ctx; | megcore::AtlasContext ctx; | ||||
| auto ret = megcore::getAtlasContext(handle, &ctx); | auto ret = megcore::getAtlasContext(handle, &ctx); | ||||
| *stream = ctx.stream; | *stream = ctx.stream; | ||||
| @@ -34,8 +34,8 @@ megcoreStatus_t createComputingHandleWithCambriconContext( | |||||
| megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | ||||
| unsigned int flags, const CambriconContext& ctx); | unsigned int flags, const CambriconContext& ctx); | ||||
| megcoreStatus_t getCambriconContext(megcoreComputingHandle_t handle, | |||||
| CambriconContext* ctx); | |||||
| megcoreStatus_t getCambriconContext( | |||||
| megcoreComputingHandle_t handle, CambriconContext* ctx); | |||||
| } // namespace megcore | } // namespace megcore | ||||
| @@ -58,4 +58,3 @@ static inline megcoreStatus_t megcoreGetCNRTQueue( | |||||
| #include "megdnn/internal/visibility_epilogue.h" | #include "megdnn/internal/visibility_epilogue.h" | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -40,7 +40,6 @@ typedef enum { | |||||
| megcoreErrorInternalError = 5, | megcoreErrorInternalError = 5, | ||||
| } megcoreStatus_t; | } megcoreStatus_t; | ||||
| /** | /** | ||||
| * \brief Memcpy kind | * \brief Memcpy kind | ||||
| */ | */ | ||||
| @@ -70,6 +69,6 @@ struct AsyncErrorInfo { | |||||
| char msg[228]; | char msg[228]; | ||||
| int msg_args[4]; | int msg_args[4]; | ||||
| }; | }; | ||||
| } // namespace megcore | |||||
| } // namespace megcore | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithCUDAContext( | |||||
| megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | ||||
| unsigned int flags, const CudaContext& ctx); | unsigned int flags, const CudaContext& ctx); | ||||
| megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle, | |||||
| CudaContext* ctx); | |||||
| megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle, CudaContext* ctx); | |||||
| } // namespace megcore | } // namespace megcore | ||||
| @@ -43,8 +42,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithCUDAStream( | |||||
| unsigned int flags, cudaStream_t stream) { | unsigned int flags, cudaStream_t stream) { | ||||
| megcore::CudaContext ctx; | megcore::CudaContext ctx; | ||||
| ctx.stream = stream; | ctx.stream = stream; | ||||
| return megcore::createComputingHandleWithCUDAContext(compHandle, devHandle, | |||||
| flags, ctx); | |||||
| return megcore::createComputingHandleWithCUDAContext( | |||||
| compHandle, devHandle, flags, ctx); | |||||
| } | } | ||||
| static inline megcoreStatus_t megcoreGetCUDAStream( | static inline megcoreStatus_t megcoreGetCUDAStream( | ||||
| @@ -23,7 +23,9 @@ struct ROCMContext { | |||||
| hipStream_t stream = nullptr; | hipStream_t stream = nullptr; | ||||
| static std::atomic_bool sm_miopen_algo_search; | static std::atomic_bool sm_miopen_algo_search; | ||||
| static inline bool enable_miopen_algo_search() { return sm_miopen_algo_search.load(); } | |||||
| static inline bool enable_miopen_algo_search() { | |||||
| return sm_miopen_algo_search.load(); | |||||
| } | |||||
| static inline void enable_miopen_algo_search(bool enable_algo_search) { | static inline void enable_miopen_algo_search(bool enable_algo_search) { | ||||
| sm_miopen_algo_search.store(enable_algo_search); | sm_miopen_algo_search.store(enable_algo_search); | ||||
| } | } | ||||
| @@ -40,8 +42,7 @@ megcoreStatus_t createComputingHandleWithROCMContext( | |||||
| megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | ||||
| unsigned int flags, const ROCMContext& ctx); | unsigned int flags, const ROCMContext& ctx); | ||||
| megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle, | |||||
| ROCMContext* ctx); | |||||
| megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle, ROCMContext* ctx); | |||||
| // Set MIOpen algo search enabled or disabled | // Set MIOpen algo search enabled or disabled | ||||
| megcoreStatus_t enableMIOpenAlgoSearch(bool enable_algo_search = true); | megcoreStatus_t enableMIOpenAlgoSearch(bool enable_algo_search = true); | ||||
| @@ -55,8 +56,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithROCMStream( | |||||
| unsigned int flags, hipStream_t stream) { | unsigned int flags, hipStream_t stream) { | ||||
| megcore::ROCMContext ctx; | megcore::ROCMContext ctx; | ||||
| ctx.stream = stream; | ctx.stream = stream; | ||||
| return megcore::createComputingHandleWithROCMContext(compHandle, devHandle, | |||||
| flags, ctx); | |||||
| return megcore::createComputingHandleWithROCMContext( | |||||
| compHandle, devHandle, flags, ctx); | |||||
| } | } | ||||
| static inline megcoreStatus_t megcoreGetROCMStream( | static inline megcoreStatus_t megcoreGetROCMStream( | ||||
| @@ -10,7 +10,7 @@ | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "megdnn/version.h" | |||||
| #include "megdnn/oprs.h" | #include "megdnn/oprs.h" | ||||
| #include "megdnn/version.h" | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -14,20 +14,20 @@ | |||||
| #include "megdnn/config/config.h" | #include "megdnn/config/config.h" | ||||
| #if defined(__GNUC__) || defined(__clang__) | #if defined(__GNUC__) || defined(__clang__) | ||||
| #if !defined (__clang__) | |||||
| // gcc specific | |||||
| #define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) | |||||
| #if GCC_VERSION < 40800 | |||||
| #error "GCC version should be at least 4.8.0." | |||||
| #endif // GCC_VERSION < 40800 | |||||
| #endif // !defined(__clang__) | |||||
| #if !defined(__clang__) | |||||
| // gcc specific | |||||
| #define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) | |||||
| #if GCC_VERSION < 40800 | |||||
| #error "GCC version should be at least 4.8.0." | |||||
| #endif // GCC_VERSION < 40800 | |||||
| #endif // !defined(__clang__) | |||||
| #ifndef megdnn_trap | |||||
| #define megdnn_trap() __builtin_trap() | |||||
| #endif | |||||
| #ifndef megdnn_trap | |||||
| #define megdnn_trap() __builtin_trap() | |||||
| #endif | |||||
| #define megdnn_likely(v) __builtin_expect(bool(v), 1) | |||||
| #define megdnn_unlikely(v) __builtin_expect(bool(v), 0) | |||||
| #define megdnn_likely(v) __builtin_expect(bool(v), 1) | |||||
| #define megdnn_unlikely(v) __builtin_expect(bool(v), 0) | |||||
| #if !defined(__clang__) && MEGDNN_ARMV7 && !defined(NDEBUG) | #if !defined(__clang__) && MEGDNN_ARMV7 && !defined(NDEBUG) | ||||
| //! Thumb2 limit code length | //! Thumb2 limit code length | ||||
| @@ -36,123 +36,122 @@ | |||||
| #define MEGDNN_ALWAYS_INLINE inline __attribute__((__always_inline__)) | #define MEGDNN_ALWAYS_INLINE inline __attribute__((__always_inline__)) | ||||
| #endif | #endif | ||||
| #define MEGDNN_DEPRECATED __attribute__((deprecated)) | |||||
| #define MEGDNN_PACKED __attribute__((packed)) | |||||
| #define MEGDNN_CONSTEXPR constexpr | |||||
| #define MEGDNN_NOEXCEPT noexcept | |||||
| #define MEGDNN_STATIC_ASSERT static_assert | |||||
| #define MEGDNN_FINAL final | |||||
| #define MEGDNN_NORETURN __attribute__((noreturn)) | |||||
| #define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result)) | |||||
| #define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) | |||||
| #if defined(__clang_major__) && (__clang_major__ >= 7) | |||||
| #define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) | |||||
| #else | |||||
| #define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]] | |||||
| #endif | |||||
| #define MEGDNN_NOINLINE __attribute__((noinline)) | |||||
| #define megdnn_isatty(x) isatty(x) | |||||
| #define MEGDNN_DEPRECATED __attribute__((deprecated)) | |||||
| #define MEGDNN_PACKED __attribute__((packed)) | |||||
| #define MEGDNN_CONSTEXPR constexpr | |||||
| #define MEGDNN_NOEXCEPT noexcept | |||||
| #define MEGDNN_STATIC_ASSERT static_assert | |||||
| #define MEGDNN_FINAL final | |||||
| #define MEGDNN_NORETURN __attribute__((noreturn)) | |||||
| #define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result)) | |||||
| #define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) | |||||
| #if defined(__clang_major__) && (__clang_major__ >= 7) | |||||
| #define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) | |||||
| #else | |||||
| #define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]] | |||||
| #endif | |||||
| #define MEGDNN_NOINLINE __attribute__((noinline)) | |||||
| #define megdnn_isatty(x) isatty(x) | |||||
| #elif defined(__INTEL_COMPILER) || defined(_MSC_VER) | #elif defined(__INTEL_COMPILER) || defined(_MSC_VER) | ||||
| #ifndef megdnn_trap | #ifndef megdnn_trap | ||||
| #define megdnn_trap() __debugbreak() | #define megdnn_trap() __debugbreak() | ||||
| #endif | #endif | ||||
| #define megdnn_likely(v) (bool(v)) | |||||
| #define megdnn_likely(v) (bool(v)) | |||||
| #define megdnn_unlikely(v) (bool(v)) | #define megdnn_unlikely(v) (bool(v)) | ||||
| #define MEGDNN_DEPRECATED | #define MEGDNN_DEPRECATED | ||||
| #define MEGDNN_PACKED | #define MEGDNN_PACKED | ||||
| #define MEGDNN_CONSTEXPR constexpr | |||||
| #define MEGDNN_NOEXCEPT noexcept | |||||
| #define MEGDNN_CONSTEXPR constexpr | |||||
| #define MEGDNN_NOEXCEPT noexcept | |||||
| #define MEGDNN_STATIC_ASSERT static_assert | #define MEGDNN_STATIC_ASSERT static_assert | ||||
| #define MEGDNN_FINAL final | |||||
| #define MEGDNN_FINAL final | |||||
| #if defined(_MSC_VER) | #if defined(_MSC_VER) | ||||
| #define MEGDNN_NORETURN __declspec(noreturn) | |||||
| #define MEGDNN_NOINLINE __declspec(noinline) | |||||
| #define MEGDNN_NORETURN __declspec(noreturn) | |||||
| #define MEGDNN_NOINLINE __declspec(noinline) | |||||
| #else | #else | ||||
| #define MEGDNN_NORETURN | |||||
| #define MEGDNN_FORCE_NOINLINE | |||||
| #endif // _MSC_VER | |||||
| #define MEGDNN_NORETURN | |||||
| #define MEGDNN_FORCE_NOINLINE | |||||
| #endif // _MSC_VER | |||||
| #define MEGDNN_WARN_UNUSED_RESULT | #define MEGDNN_WARN_UNUSED_RESULT | ||||
| #define megdnn_isatty(x) _isatty(x) | #define megdnn_isatty(x) _isatty(x) | ||||
| #else | #else | ||||
| #error "unknown compiler" | |||||
| #endif // __GNUC__ | |||||
| #error "unknown compiler" | |||||
| #endif // __GNUC__ | |||||
| // __cpp_exceptions and __cpp_rtti are taken from | // __cpp_exceptions and __cpp_rtti are taken from | ||||
| // https://isocpp.org/std/standing-documents/sd-6-sg10-feature-test-recommendations | // https://isocpp.org/std/standing-documents/sd-6-sg10-feature-test-recommendations | ||||
| // gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS, | |||||
| // gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS, | |||||
| // similar for __GXX_RTTI | // similar for __GXX_RTTI | ||||
| // _CPPUNWIND and _CPPRTTI are used by MSVC, see | // _CPPUNWIND and _CPPRTTI are used by MSVC, see | ||||
| // https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macros?view=vs-2019 | // https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macros?view=vs-2019 | ||||
| #ifndef MEGDNN_ENABLE_EXCEPTIONS | #ifndef MEGDNN_ENABLE_EXCEPTIONS | ||||
| #if __cpp_exceptions || __EXCEPTIONS || \ | |||||
| (defined(_MSC_VER) && defined(_CPPUNWIND)) | |||||
| #define MEGDNN_ENABLE_EXCEPTIONS 1 | |||||
| #else | |||||
| #define MEGDNN_ENABLE_EXCEPTIONS 0 | |||||
| #endif | |||||
| #if __cpp_exceptions || __EXCEPTIONS || (defined(_MSC_VER) && defined(_CPPUNWIND)) | |||||
| #define MEGDNN_ENABLE_EXCEPTIONS 1 | |||||
| #else | |||||
| #define MEGDNN_ENABLE_EXCEPTIONS 0 | |||||
| #endif | |||||
| #endif | #endif | ||||
| #ifndef MEGDNN_ENABLE_RTTI | #ifndef MEGDNN_ENABLE_RTTI | ||||
| #if __cpp_rtti || __GXX_RTTI || (defined(_MSC_VER) && defined(_CPPRTTI)) | |||||
| #define MEGDNN_ENABLE_RTTI 1 | |||||
| #else | |||||
| #define MEGDNN_ENABLE_RTTI 0 | |||||
| #endif | |||||
| #if __cpp_rtti || __GXX_RTTI || (defined(_MSC_VER) && defined(_CPPRTTI)) | |||||
| #define MEGDNN_ENABLE_RTTI 1 | |||||
| #else | |||||
| #define MEGDNN_ENABLE_RTTI 0 | |||||
| #endif | |||||
| #endif | #endif | ||||
| #ifdef __CUDACC__ | #ifdef __CUDACC__ | ||||
| #define MEGDNN_CC_CUDA 1 | |||||
| #undef MEGDNN_CONSTEXPR | |||||
| #define MEGDNN_CONSTEXPR const | |||||
| #define MEGDNN_CC_CUDA 1 | |||||
| #undef MEGDNN_CONSTEXPR | |||||
| #define MEGDNN_CONSTEXPR const | |||||
| #if defined(__CUDACC_VER_MAJOR__) | #if defined(__CUDACC_VER_MAJOR__) | ||||
| #if __CUDACC_VER_MAJOR__ >= 9 | #if __CUDACC_VER_MAJOR__ >= 9 | ||||
| #undef MEGDNN_STATIC_ASSERT | |||||
| #define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg); | |||||
| #undef MEGDNN_STATIC_ASSERT | |||||
| #define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg); | |||||
| #else | #else | ||||
| #undef MEGDNN_STATIC_ASSERT | |||||
| #define MEGDNN_STATIC_ASSERT(cond, msg) | |||||
| #undef MEGDNN_STATIC_ASSERT | |||||
| #define MEGDNN_STATIC_ASSERT(cond, msg) | |||||
| #endif | #endif | ||||
| #endif | #endif | ||||
| #define nullptr NULL | |||||
| #undef MEGDNN_FINAL | |||||
| #define MEGDNN_FINAL | |||||
| #define nullptr NULL | |||||
| #undef MEGDNN_FINAL | |||||
| #define MEGDNN_FINAL | |||||
| #elif defined(__HIPCC__) | #elif defined(__HIPCC__) | ||||
| #define MEGDNN_CC_CUDA 1 | |||||
| #define MEGDNN_CC_CUDA 1 | |||||
| #else | #else | ||||
| #define MEGDNN_CC_HOST 1 | |||||
| #endif // __CUDACC__ | |||||
| #define MEGDNN_CC_HOST 1 | |||||
| #endif // __CUDACC__ | |||||
| // MEGDNN_HOST and MEGDNN_DEVICE | // MEGDNN_HOST and MEGDNN_DEVICE | ||||
| #if MEGDNN_CC_CUDA | #if MEGDNN_CC_CUDA | ||||
| #define MEGDNN_HOST __host__ | |||||
| #define MEGDNN_DEVICE __device__ | |||||
| #define MEGDNN_HOST __host__ | |||||
| #define MEGDNN_DEVICE __device__ | |||||
| #else | #else | ||||
| #define MEGDNN_HOST | |||||
| #define MEGDNN_DEVICE | |||||
| #define MEGDNN_HOST | |||||
| #define MEGDNN_DEVICE | |||||
| #endif | #endif | ||||
| #if MEGDNN_CC_CUDA | #if MEGDNN_CC_CUDA | ||||
| #define MEGDNN_FORCE_INLINE __forceinline__ | |||||
| #define MEGDNN_FORCE_INLINE __forceinline__ | |||||
| #else | #else | ||||
| #if __GNUC__ || __has_attribute(always_inline) | #if __GNUC__ || __has_attribute(always_inline) | ||||
| #define MEGDNN_FORCE_INLINE inline __attribute__((always_inline)) | |||||
| #define MEGDNN_FORCE_INLINE inline __attribute__((always_inline)) | |||||
| #else | #else | ||||
| #define MEGDNN_FORCE_INLINE inline | |||||
| #define MEGDNN_FORCE_INLINE inline | |||||
| #endif | #endif | ||||
| #endif | #endif | ||||
| #if defined(_MSC_VER) || defined(WIN32) | #if defined(_MSC_VER) || defined(WIN32) | ||||
| #define ATTR_ALIGNED(v) __declspec(align(v)) | |||||
| #define ATTR_ALIGNED(v) __declspec(align(v)) | |||||
| #else | #else | ||||
| #define ATTR_ALIGNED(v) __attribute__((aligned(v))) | |||||
| #define ATTR_ALIGNED(v) __attribute__((aligned(v))) | |||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -16,10 +16,10 @@ | |||||
| #include "megdnn/internal/defs.h" | #include "megdnn/internal/defs.h" | ||||
| #if MEGDNN_CC_HOST | #if MEGDNN_CC_HOST | ||||
| #include <cstdarg> | |||||
| #include <string> | #include <string> | ||||
| #include <type_traits> | #include <type_traits> | ||||
| #include <vector> | #include <vector> | ||||
| #include <cstdarg> | |||||
| #include "megdnn/thin/small_vector.h" | #include "megdnn/thin/small_vector.h" | ||||
| #endif // MEGDNN_CC_HOST | #endif // MEGDNN_CC_HOST | ||||
| @@ -35,8 +35,7 @@ class ErrorHandler { | |||||
| protected: | protected: | ||||
| MEGDNN_NORETURN virtual void do_on_megdnn_error(const std::string& msg) = 0; | MEGDNN_NORETURN virtual void do_on_megdnn_error(const std::string& msg) = 0; | ||||
| MEGDNN_NORETURN virtual void do_on_tensor_reshape_error( | |||||
| const std::string& msg) { | |||||
| MEGDNN_NORETURN virtual void do_on_tensor_reshape_error(const std::string& msg) { | |||||
| on_megdnn_error(msg); | on_megdnn_error(msg); | ||||
| } | } | ||||
| @@ -70,8 +69,9 @@ public: | |||||
| #if MEGDNN_CC_HOST | #if MEGDNN_CC_HOST | ||||
| enum class LogLevel { DEBUG, INFO, WARN, ERROR }; | enum class LogLevel { DEBUG, INFO, WARN, ERROR }; | ||||
| typedef void (*LogHandler)(LogLevel level, const char* file, const char* func, | |||||
| int line, const char* fmt, va_list ap); | |||||
| typedef void (*LogHandler)( | |||||
| LogLevel level, const char* file, const char* func, int line, const char* fmt, | |||||
| va_list ap); | |||||
| /*! | /*! | ||||
| * \brief set the callback to receive all log messages | * \brief set the callback to receive all log messages | ||||
| @@ -144,8 +144,7 @@ struct TensorLayout : public TensorShape { | |||||
| ptrdiff_t low_elem, low_byte; | ptrdiff_t low_elem, low_byte; | ||||
| size_t high_elem, high_byte; | size_t high_elem, high_byte; | ||||
| Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem, | |||||
| size_t high_byte) | |||||
| Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem, size_t high_byte) | |||||
| : low_elem(low_elem), | : low_elem(low_elem), | ||||
| low_byte(low_byte), | low_byte(low_byte), | ||||
| high_elem(high_elem), | high_elem(high_elem), | ||||
| @@ -235,11 +234,13 @@ struct TensorLayout : public TensorShape { | |||||
| TensorLayout(const TensorShape& shape, DType dtype, Format format); | TensorLayout(const TensorShape& shape, DType dtype, Format format); | ||||
| //! creating layout with user-specified shape and stride. | //! creating layout with user-specified shape and stride. | ||||
| TensorLayout(const TensorShape& shape, const std::vector<ptrdiff_t>& stride, | |||||
| DType dtype); | |||||
| TensorLayout( | |||||
| const TensorShape& shape, const std::vector<ptrdiff_t>& stride, | |||||
| DType dtype); | |||||
| TensorLayout(const TensorShape& shape, const std::vector<ptrdiff_t>& stride, | |||||
| DType dtype, Format format); | |||||
| TensorLayout( | |||||
| const TensorShape& shape, const std::vector<ptrdiff_t>& stride, DType dtype, | |||||
| Format format); | |||||
| /* =================== inplace modifiers =================== */ | /* =================== inplace modifiers =================== */ | ||||
| @@ -310,8 +311,7 @@ struct TensorLayout : public TensorShape { | |||||
| * | * | ||||
| * \throw TensorReshapeError if no stride exists for target shape. | * \throw TensorReshapeError if no stride exists for target shape. | ||||
| */ | */ | ||||
| TensorLayout reshape(const TensorShape& shape) const | |||||
| MEGDNN_WARN_UNUSED_RESULT; | |||||
| TensorLayout reshape(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT; | |||||
| /*! | /*! | ||||
| * \brief try to reshape to another view; return whether these two shapes | * \brief try to reshape to another view; return whether these two shapes | ||||
| @@ -319,15 +319,14 @@ struct TensorLayout : public TensorShape { | |||||
| * \return true iff there exists target stride so this layout can be | * \return true iff there exists target stride so this layout can be | ||||
| * converted to target shape and the elements can match. | * converted to target shape and the elements can match. | ||||
| */ | */ | ||||
| bool try_reshape(TensorLayout& output, | |||||
| const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT; | |||||
| bool try_reshape(TensorLayout& output, const TensorShape& shape) const | |||||
| MEGDNN_WARN_UNUSED_RESULT; | |||||
| /*! | /*! | ||||
| * \brief Broadcast on dims with shape == 1 to match target *shape*. | * \brief Broadcast on dims with shape == 1 to match target *shape*. | ||||
| * \throw TensorReshapeError if could not be satisfied | * \throw TensorReshapeError if could not be satisfied | ||||
| */ | */ | ||||
| TensorLayout broadcast(const TensorShape& shape) const | |||||
| MEGDNN_WARN_UNUSED_RESULT; | |||||
| TensorLayout broadcast(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT; | |||||
| /*! | /*! | ||||
| * \brief Collapse consecutive axes with contiguous layout together | * \brief Collapse consecutive axes with contiguous layout together | ||||
| @@ -441,8 +440,7 @@ struct Workspace { | |||||
| Workspace() : raw_ptr(NULL), size(0) {} | Workspace() : raw_ptr(NULL), size(0) {} | ||||
| Workspace(dt_byte* raw_ptr_, size_t size_) | |||||
| : raw_ptr(raw_ptr_), size(size_) {} | |||||
| Workspace(dt_byte* raw_ptr_, size_t size_) : raw_ptr(raw_ptr_), size(size_) {} | |||||
| template <typename T> | template <typename T> | ||||
| T* ptr(size_t offset_in_bytes = 0) const { | T* ptr(size_t offset_in_bytes = 0) const { | ||||
| @@ -467,9 +465,8 @@ public: | |||||
| * \param shape requested output shape | * \param shape requested output shape | ||||
| * \param user_data extra user data passed in DynOutMallocPolicyCall | * \param user_data extra user data passed in DynOutMallocPolicyCall | ||||
| */ | */ | ||||
| virtual TensorND alloc_output(size_t id, DType dtype, | |||||
| const TensorShape& shape, | |||||
| void* user_data) = 0; | |||||
| virtual TensorND alloc_output( | |||||
| size_t id, DType dtype, const TensorShape& shape, void* user_data) = 0; | |||||
| /*! | /*! | ||||
| * \brief allocate workspace memory | * \brief allocate workspace memory | ||||
| @@ -508,19 +505,15 @@ struct DynOutMallocPolicyCall { | |||||
| */ | */ | ||||
| template <typename T = void, typename elem = T> | template <typename T = void, typename elem = T> | ||||
| T* alloc_workspace(size_t nr_elem) { | T* alloc_workspace(size_t nr_elem) { | ||||
| using real_elem = | |||||
| typename std::conditional<std::is_same<elem, void>::value, | |||||
| uint8_t, elem>::type; | |||||
| return static_cast<T*>(policy->alloc_workspace( | |||||
| nr_elem * sizeof(real_elem), user_data)); | |||||
| using real_elem = typename std::conditional< | |||||
| std::is_same<elem, void>::value, uint8_t, elem>::type; | |||||
| return static_cast<T*>( | |||||
| policy->alloc_workspace(nr_elem * sizeof(real_elem), user_data)); | |||||
| } | } | ||||
| void free_workspace(void* ptr) { | |||||
| return policy->free_workspace(ptr, user_data); | |||||
| } | |||||
| void free_workspace(void* ptr) { return policy->free_workspace(ptr, user_data); } | |||||
| }; | }; | ||||
| template <typename T> | template <typename T> | ||||
| class EnumClassBit { | class EnumClassBit { | ||||
| std::underlying_type_t<T> m_val; | std::underlying_type_t<T> m_val; | ||||
| @@ -528,8 +521,7 @@ class EnumClassBit { | |||||
| constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {} | constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {} | ||||
| public: | public: | ||||
| constexpr EnumClassBit(T v) | |||||
| : m_val(static_cast<std::underlying_type_t<T>>(v)) {} | |||||
| constexpr EnumClassBit(T v) : m_val(static_cast<std::underlying_type_t<T>>(v)) {} | |||||
| constexpr operator T() const { return static_cast<T>(m_val); } | constexpr operator T() const { return static_cast<T>(m_val); } | ||||
| @@ -542,7 +534,7 @@ public: | |||||
| DEF_OPR(&) | DEF_OPR(&) | ||||
| DEF_OPR(|) | DEF_OPR(|) | ||||
| DEF_OPR (^) | |||||
| DEF_OPR(^) | |||||
| constexpr EnumClassBit operator~() const { return ~m_val; } | constexpr EnumClassBit operator~() const { return ~m_val; } | ||||
| @@ -553,14 +545,13 @@ public: | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \ | |||||
| inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \ | |||||
| return ::megdnn::EnumClassBit<cls>(x) \ | |||||
| op ::megdnn::EnumClassBit<cls>(y); \ | |||||
| } \ | |||||
| inline constexpr ::megdnn::EnumClassBit<cls> operator op( \ | |||||
| ::megdnn::EnumClassBit<cls> x, cls y) { \ | |||||
| return x op ::megdnn::EnumClassBit<cls>(y); \ | |||||
| #define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \ | |||||
| inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \ | |||||
| return ::megdnn::EnumClassBit<cls>(x) op ::megdnn::EnumClassBit<cls>(y); \ | |||||
| } \ | |||||
| inline constexpr ::megdnn::EnumClassBit<cls> operator op( \ | |||||
| ::megdnn::EnumClassBit<cls> x, cls y) { \ | |||||
| return x op ::megdnn::EnumClassBit<cls>(y); \ | |||||
| } | } | ||||
| #define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \ | #define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \ | ||||
| @@ -14,14 +14,14 @@ | |||||
| #include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
| #if MGB_ENABLE_GETENV | #if MGB_ENABLE_GETENV | ||||
| #define MGB_GETENV ::std::getenv | |||||
| #define MGB_GETENV ::std::getenv | |||||
| #else | #else | ||||
| #define MGB_GETENV(_name) static_cast<char*>(nullptr) | |||||
| #define MGB_GETENV(_name) static_cast<char*>(nullptr) | |||||
| #endif | #endif | ||||
| #ifdef WIN32 | #ifdef WIN32 | ||||
| #define unsetenv(_name) _putenv_s(_name, ""); | |||||
| #define setenv(name,value,overwrite) _putenv_s(name,value) | |||||
| #define unsetenv(_name) _putenv_s(_name, ""); | |||||
| #define setenv(name, value, overwrite) _putenv_s(name, value) | |||||
| #endif | #endif | ||||
| namespace megdnn { | namespace megdnn { | ||||
| @@ -32,8 +32,7 @@ namespace megdnn { | |||||
| */ | */ | ||||
| template <class Opr, typename... Args> | template <class Opr, typename... Args> | ||||
| bool has_available_algo(Opr* opr, Args&&... args) { | bool has_available_algo(Opr* opr, Args&&... args) { | ||||
| const typename Opr::AlgoBase::SizeArgs size_args( | |||||
| opr, std::forward<Args>(args)...); | |||||
| const typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...); | |||||
| for (auto i : Opr::algo_pack().all_algos) { | for (auto i : Opr::algo_pack().all_algos) { | ||||
| if (i->is_available(size_args)) { | if (i->is_available(size_args)) { | ||||
| return true; | return true; | ||||
| @@ -42,6 +41,6 @@ bool has_available_algo(Opr* opr, Args&&... args) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -17,11 +17,11 @@ | |||||
| #include "megdnn/internal/visibility_prologue.h" | #include "megdnn/internal/visibility_prologue.h" | ||||
| namespace megdnn { | namespace megdnn { | ||||
| std::unique_ptr<Handle> make_cuda_handle_with_stream(cudaStream_t stream, | |||||
| int device_id = -1); | |||||
| cudaStream_t get_cuda_stream(Handle *handle); | |||||
| std::unique_ptr<Handle> make_cuda_handle_with_stream( | |||||
| cudaStream_t stream, int device_id = -1); | |||||
| cudaStream_t get_cuda_stream(Handle* handle); | |||||
| } // namespace megdnn | |||||
| } // namespace megdnn | |||||
| #include "megdnn/internal/visibility_epilogue.h" | #include "megdnn/internal/visibility_epilogue.h" | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -3,17 +3,22 @@ | |||||
| * | * | ||||
| * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net> | * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net> | ||||
| * | * | ||||
| * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation | |||||
| * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, | |||||
| * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the | |||||
| * Software is furnished to do so, subject to the following conditions: | |||||
| * | |||||
| * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. | |||||
| * | |||||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE | |||||
| * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR | |||||
| * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, | |||||
| * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |||||
| * Permission is hereby granted, free of charge, to any person obtaining a copy of this | |||||
| * software and associated documentation files (the "Software"), to deal in the Software | |||||
| * without restriction, including without limitation the rights to use, copy, modify, | |||||
| * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to | |||||
| * permit persons to whom the Software is furnished to do so, subject to the following | |||||
| * conditions: | |||||
| * | |||||
| * The above copyright notice and this permission notice shall be included in all copies | |||||
| * or substantial portions of the Software. | |||||
| * | |||||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, | |||||
| * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A | |||||
| * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT | |||||
| * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF | |||||
| * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE | |||||
| * OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |||||
| * | * | ||||
| * Version 1.11.0 | * Version 1.11.0 | ||||
| * \file | * \file | ||||
| @@ -41,8 +46,8 @@ | |||||
| #undef HALF_NOEXCEPT | #undef HALF_NOEXCEPT | ||||
| #undef HALF_NOTHROW | #undef HALF_NOTHROW | ||||
| #ifdef HALF_POP_WARNINGS | #ifdef HALF_POP_WARNINGS | ||||
| #pragma warning(pop) | |||||
| #undef HALF_POP_WARNINGS | |||||
| #pragma warning(pop) | |||||
| #undef HALF_POP_WARNINGS | |||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -3,17 +3,22 @@ | |||||
| * | * | ||||
| * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net> | * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net> | ||||
| * | * | ||||
| * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation | |||||
| * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, | |||||
| * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the | |||||
| * Software is furnished to do so, subject to the following conditions: | |||||
| * Permission is hereby granted, free of charge, to any person obtaining a copy of this | |||||
| * software and associated documentation files (the "Software"), to deal in the Software | |||||
| * without restriction, including without limitation the rights to use, copy, modify, | |||||
| * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to | |||||
| * permit persons to whom the Software is furnished to do so, subject to the following | |||||
| * conditions: | |||||
| * | * | ||||
| * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. | |||||
| * The above copyright notice and this permission notice shall be included in all copies | |||||
| * or substantial portions of the Software. | |||||
| * | * | ||||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE | |||||
| * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR | |||||
| * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, | |||||
| * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |||||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, | |||||
| * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A | |||||
| * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT | |||||
| * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF | |||||
| * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE | |||||
| * OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |||||
| * | * | ||||
| * Version 1.11.0 | * Version 1.11.0 | ||||
| * \file | * \file | ||||
| @@ -39,166 +44,164 @@ | |||||
| #include "megdnn/arch.h" | #include "megdnn/arch.h" | ||||
| /// Combined gcc version number. | /// Combined gcc version number. | ||||
| #define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__) | |||||
| #define HALF_GNUC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) | |||||
| //check C++11 language features | |||||
| #if defined(__clang__) //clang | |||||
| #if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||||
| #endif | |||||
| #if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) | |||||
| #define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||||
| #endif | |||||
| #if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) | |||||
| #define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||||
| #endif | |||||
| #if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) | |||||
| #define HALF_ENABLE_CPP11_USER_LITERALS 1 | |||||
| #endif | |||||
| #if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||||
| #endif | |||||
| /*#elif defined(__INTEL_COMPILER) //Intel C++ | |||||
| #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ???????? | |||||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||||
| #endif | |||||
| #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? | |||||
| #define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||||
| #endif | |||||
| #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? | |||||
| #define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||||
| #endif | |||||
| #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? | |||||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||||
| #endif*/ | |||||
| #elif defined(__GNUC__) //gcc | |||||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L | |||||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||||
| #endif | |||||
| #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) | |||||
| #define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||||
| #endif | |||||
| #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) | |||||
| #define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||||
| #endif | |||||
| #if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) | |||||
| #define HALF_ENABLE_CPP11_USER_LITERALS 1 | |||||
| #endif | |||||
| #if !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||||
| #endif | |||||
| #endif | |||||
| #elif defined(_MSC_VER) //Visual C++ | |||||
| #if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||||
| #endif | |||||
| #if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||||
| #endif | |||||
| #define HALF_POP_WARNINGS 1 | |||||
| #pragma warning(push) | |||||
| //! 4521 and 4522 is multiple copy/assigment operator specified | |||||
| #pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned | |||||
| // check C++11 language features | |||||
| #if defined(__clang__) // clang | |||||
| #if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||||
| #endif | |||||
| #if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) | |||||
| #define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||||
| #endif | |||||
| #if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) | |||||
| #define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||||
| #endif | |||||
| #if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) | |||||
| #define HALF_ENABLE_CPP11_USER_LITERALS 1 | |||||
| #endif | |||||
| #if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && \ | |||||
| !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||||
| #endif | |||||
| /*#elif defined(__INTEL_COMPILER) | |||||
| //Intel C++ #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||||
| ???????? #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 #endif #if __INTEL_COMPILER >= | |||||
| 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? #define | |||||
| HALF_ENABLE_CPP11_CONSTEXPR 1 #endif #if __INTEL_COMPILER >= 1300 && | |||||
| !defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? #define | |||||
| HALF_ENABLE_CPP11_NOEXCEPT 1 #endif #if __INTEL_COMPILER >= 1100 && | |||||
| !defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? #define | |||||
| HALF_ENABLE_CPP11_LONG_LONG 1 #endif*/ | |||||
| #elif defined(__GNUC__) // gcc | |||||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L | |||||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||||
| #endif | |||||
| #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) | |||||
| #define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||||
| #endif | |||||
| #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) | |||||
| #define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||||
| #endif | |||||
| #if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) | |||||
| #define HALF_ENABLE_CPP11_USER_LITERALS 1 | |||||
| #endif | |||||
| #if !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||||
| #endif | |||||
| #endif | |||||
| #elif defined(_MSC_VER) // Visual C++ | |||||
| #if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||||
| #endif | |||||
| #if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||||
| #endif | |||||
| #define HALF_POP_WARNINGS 1 | |||||
| #pragma warning(push) | |||||
| //! 4521 and 4522 is multiple copy/assigment operator specified | |||||
| #pragma warning(disable : 4099 4127 4146 4521 4522) // struct vs class, constant in if, | |||||
| // negative unsigned | |||||
| #endif | #endif | ||||
| //check C++11 library features | |||||
| // check C++11 library features | |||||
| #include <utility> | #include <utility> | ||||
| #if defined(_LIBCPP_VERSION) //libc++ | |||||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 | |||||
| #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS | |||||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||||
| #endif | |||||
| #ifndef HALF_ENABLE_CPP11_CSTDINT | |||||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||||
| #endif | |||||
| #ifndef HALF_ENABLE_CPP11_CMATH | |||||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||||
| #endif | |||||
| #ifndef HALF_ENABLE_CPP11_HASH | |||||
| #define HALF_ENABLE_CPP11_HASH 1 | |||||
| #endif | |||||
| #endif | |||||
| #elif defined(__GLIBCXX__) //libstdc++ | |||||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 | |||||
| #ifdef __clang__ | |||||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) | |||||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||||
| #endif | |||||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) | |||||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||||
| #endif | |||||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) | |||||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||||
| #endif | |||||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) | |||||
| #define HALF_ENABLE_CPP11_HASH 1 | |||||
| #endif | |||||
| #else | |||||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) | |||||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||||
| #endif | |||||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) | |||||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||||
| #endif | |||||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) | |||||
| #define HALF_ENABLE_CPP11_HASH 1 | |||||
| #endif | |||||
| #endif | |||||
| #endif | |||||
| #elif defined(_CPPLIB_VER) //Dinkumware/Visual C++ | |||||
| #if _CPPLIB_VER >= 520 | |||||
| #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS | |||||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||||
| #endif | |||||
| #ifndef HALF_ENABLE_CPP11_CSTDINT | |||||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||||
| #endif | |||||
| #ifndef HALF_ENABLE_CPP11_HASH | |||||
| #define HALF_ENABLE_CPP11_HASH 1 | |||||
| #endif | |||||
| #endif | |||||
| #if _CPPLIB_VER >= 610 | |||||
| #ifndef HALF_ENABLE_CPP11_CMATH | |||||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||||
| #endif | |||||
| #endif | |||||
| #if defined(_LIBCPP_VERSION) // libc++ | |||||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 | |||||
| #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS | |||||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||||
| #endif | |||||
| #ifndef HALF_ENABLE_CPP11_CSTDINT | |||||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||||
| #endif | |||||
| #ifndef HALF_ENABLE_CPP11_CMATH | |||||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||||
| #endif | |||||
| #ifndef HALF_ENABLE_CPP11_HASH | |||||
| #define HALF_ENABLE_CPP11_HASH 1 | |||||
| #endif | |||||
| #endif | |||||
| #elif defined(__GLIBCXX__) // libstdc++ | |||||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 | |||||
| #ifdef __clang__ | |||||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) | |||||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||||
| #endif | |||||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) | |||||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||||
| #endif | |||||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) | |||||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||||
| #endif | |||||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) | |||||
| #define HALF_ENABLE_CPP11_HASH 1 | |||||
| #endif | |||||
| #else | |||||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) | |||||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||||
| #endif | |||||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) | |||||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||||
| #endif | |||||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) | |||||
| #define HALF_ENABLE_CPP11_HASH 1 | |||||
| #endif | |||||
| #endif | |||||
| #endif | |||||
| #elif defined(_CPPLIB_VER) // Dinkumware/Visual C++ | |||||
| #if _CPPLIB_VER >= 520 | |||||
| #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS | |||||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||||
| #endif | |||||
| #ifndef HALF_ENABLE_CPP11_CSTDINT | |||||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||||
| #endif | |||||
| #ifndef HALF_ENABLE_CPP11_HASH | |||||
| #define HALF_ENABLE_CPP11_HASH 1 | |||||
| #endif | |||||
| #endif | |||||
| #if _CPPLIB_VER >= 610 | |||||
| #ifndef HALF_ENABLE_CPP11_CMATH | |||||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||||
| #endif | |||||
| #endif | |||||
| #endif | #endif | ||||
| #undef HALF_GNUC_VERSION | #undef HALF_GNUC_VERSION | ||||
| //support constexpr | |||||
| // support constexpr | |||||
| #if HALF_ENABLE_CPP11_CONSTEXPR | #if HALF_ENABLE_CPP11_CONSTEXPR | ||||
| #define HALF_CONSTEXPR constexpr | |||||
| #define HALF_CONSTEXPR_CONST constexpr | |||||
| #define HALF_CONSTEXPR constexpr | |||||
| #define HALF_CONSTEXPR_CONST constexpr | |||||
| #else | #else | ||||
| #define HALF_CONSTEXPR | |||||
| #define HALF_CONSTEXPR_CONST const | |||||
| #define HALF_CONSTEXPR | |||||
| #define HALF_CONSTEXPR_CONST const | |||||
| #endif | #endif | ||||
| //support noexcept | |||||
| // support noexcept | |||||
| #if HALF_ENABLE_CPP11_NOEXCEPT | #if HALF_ENABLE_CPP11_NOEXCEPT | ||||
| #define HALF_NOEXCEPT noexcept | |||||
| #define HALF_NOTHROW noexcept | |||||
| #define HALF_NOEXCEPT noexcept | |||||
| #define HALF_NOTHROW noexcept | |||||
| #else | #else | ||||
| #define HALF_NOEXCEPT | |||||
| #define HALF_NOTHROW throw() | |||||
| #define HALF_NOEXCEPT | |||||
| #define HALF_NOTHROW throw() | |||||
| #endif | #endif | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <limits> | |||||
| #include <climits> | #include <climits> | ||||
| #include <cmath> | #include <cmath> | ||||
| #include <cstring> | #include <cstring> | ||||
| #include <ostream> | |||||
| #include <istream> | #include <istream> | ||||
| #include <limits> | |||||
| #include <ostream> | |||||
| #if HALF_ENABLE_CPP11_TYPE_TRAITS | #if HALF_ENABLE_CPP11_TYPE_TRAITS | ||||
| #include <type_traits> | |||||
| #include <type_traits> | |||||
| #endif | #endif | ||||
| #if HALF_ENABLE_CPP11_CSTDINT | #if HALF_ENABLE_CPP11_CSTDINT | ||||
| #include <cstdint> | |||||
| #include <cstdint> | |||||
| #endif | #endif | ||||
| #if HALF_ENABLE_CPP11_HASH | #if HALF_ENABLE_CPP11_HASH | ||||
| #include <functional> | |||||
| #include <functional> | |||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -12,8 +12,8 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "megcore.h" | #include "megcore.h" | ||||
| #include "megdnn/config/config.h" | |||||
| #include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
| #include "megdnn/config/config.h" | |||||
| #include <functional> | #include <functional> | ||||
| #include <memory> | #include <memory> | ||||
| @@ -24,150 +24,147 @@ namespace megdnn { | |||||
| class OperatorBase; | class OperatorBase; | ||||
| class Handle { | class Handle { | ||||
| public: | |||||
| enum class HandleType { | |||||
| NAIVE = 0, | |||||
| FALLBACK = 1, | |||||
| X86 = 2, | |||||
| ARM_COMMON = 3, | |||||
| ARMV7 = 4, | |||||
| AARCH64 = 5, | |||||
| CUDA = 6, | |||||
| ROCM = 11, | |||||
| ATLAS = 13, | |||||
| CAMBRICON = 12, | |||||
| }; | |||||
| //! Device vendor | |||||
| enum class HandleVendorType : uint32_t { | |||||
| NOT_SPEC = 0, | |||||
| MALI = 1, | |||||
| ADRENO = 2, | |||||
| CUDA = 3, | |||||
| INTEL = 4, | |||||
| POWERVR = 5, | |||||
| AMD = 6, | |||||
| }; | |||||
| protected: | |||||
| Handle(megcoreComputingHandle_t computing_handle, HandleType type); | |||||
| public: | |||||
| /** | |||||
| * \brief Create a MegDNN handle from a MegCore Computing handle. | |||||
| * | |||||
| * \param[in] computing_handle MegCore computing handle. Please note | |||||
| * that computing_handle would not be released when this Handle is | |||||
| * destructed | |||||
| * \param[in] debug_level | |||||
| * Applicable for CPU computing handle. | |||||
| * 0 means taking the fastest possible code path; it may contains | |||||
| * platform-specific instructions such as SSE for x86_64 or NEON for | |||||
| * armv7v7. | |||||
| * 1 means taking the fastest possible code path without | |||||
| * platform-specific instructions in C++ code. Note that the compiled | |||||
| * binary file still contains platform-specific codes. | |||||
| * 2 means taking the naive code path. Performance is severely | |||||
| * hampered, but it is less error-prone since the internal | |||||
| * implementation is rather straightforward. | |||||
| * | |||||
| * **Debug level 1 and 2 should not be used in productions.** | |||||
| */ | |||||
| static std::unique_ptr<Handle> make( | |||||
| megcoreComputingHandle_t computing_handle, | |||||
| int debug_level = 0); | |||||
| public: | |||||
| enum class HandleType { | |||||
| NAIVE = 0, | |||||
| FALLBACK = 1, | |||||
| X86 = 2, | |||||
| ARM_COMMON = 3, | |||||
| ARMV7 = 4, | |||||
| AARCH64 = 5, | |||||
| CUDA = 6, | |||||
| ROCM = 11, | |||||
| ATLAS = 13, | |||||
| CAMBRICON = 12, | |||||
| }; | |||||
| //! Device vendor | |||||
| enum class HandleVendorType : uint32_t { | |||||
| NOT_SPEC = 0, | |||||
| MALI = 1, | |||||
| ADRENO = 2, | |||||
| CUDA = 3, | |||||
| INTEL = 4, | |||||
| POWERVR = 5, | |||||
| AMD = 6, | |||||
| }; | |||||
| protected: | |||||
| Handle(megcoreComputingHandle_t computing_handle, HandleType type); | |||||
| public: | |||||
| /** | |||||
| * \brief Create a MegDNN handle from a MegCore Computing handle. | |||||
| * | |||||
| * \param[in] computing_handle MegCore computing handle. Please note | |||||
| * that computing_handle would not be released when this Handle is | |||||
| * destructed | |||||
| * \param[in] debug_level | |||||
| * Applicable for CPU computing handle. | |||||
| * 0 means taking the fastest possible code path; it may contains | |||||
| * platform-specific instructions such as SSE for x86_64 or NEON for | |||||
| * armv7v7. | |||||
| * 1 means taking the fastest possible code path without | |||||
| * platform-specific instructions in C++ code. Note that the compiled | |||||
| * binary file still contains platform-specific codes. | |||||
| * 2 means taking the naive code path. Performance is severely | |||||
| * hampered, but it is less error-prone since the internal | |||||
| * implementation is rather straightforward. | |||||
| * | |||||
| * **Debug level 1 and 2 should not be used in productions.** | |||||
| */ | |||||
| static std::unique_ptr<Handle> make( | |||||
| megcoreComputingHandle_t computing_handle, int debug_level = 0); | |||||
| #if MEGDNN_WITH_CUDA | #if MEGDNN_WITH_CUDA | ||||
| static std::unique_ptr<Handle> make_cuda_handle( | |||||
| megcoreComputingHandle_t computing_handle); | |||||
| template <typename opr> | |||||
| std::unique_ptr<opr> create_cuda_operator(); | |||||
| static std::unique_ptr<Handle> make_cuda_handle( | |||||
| megcoreComputingHandle_t computing_handle); | |||||
| template <typename opr> | |||||
| std::unique_ptr<opr> create_cuda_operator(); | |||||
| #endif | #endif | ||||
| #if MEGDNN_WITH_ROCM | #if MEGDNN_WITH_ROCM | ||||
| static std::unique_ptr<Handle> make_rocm_handle( | |||||
| megcoreComputingHandle_t computing_handle); | |||||
| template <typename opr> | |||||
| std::unique_ptr<opr> create_rocm_operator(); | |||||
| static std::unique_ptr<Handle> make_rocm_handle( | |||||
| megcoreComputingHandle_t computing_handle); | |||||
| template <typename opr> | |||||
| std::unique_ptr<opr> create_rocm_operator(); | |||||
| #endif | #endif | ||||
| virtual ~Handle(); | |||||
| /*! | |||||
| * \brief Get the underlying megcore computing handle. | |||||
| */ | |||||
| megcoreComputingHandle_t megcore_computing_handle() const { | |||||
| return m_computing_handle; | |||||
| } | |||||
| /*! | |||||
| * \brief set a callback function to be invoked when this handle is | |||||
| * destructed, so associated resources can be released (e.g. | |||||
| * computing handle) | |||||
| * | |||||
| * This function can be called at most once. | |||||
| */ | |||||
| void set_destructor(const thin_function<void()> &d); | |||||
| /*! | |||||
| * \brief set a callback to be invoked when an operator is destructed | |||||
| * \param[in,out] cb the callback function; it would be set to the | |||||
| * previous callback function | |||||
| */ | |||||
| void set_opr_destruct_callback(thin_function<void(OperatorBase*)> &cb) { | |||||
| cb.swap(m_on_opr_destructed); | |||||
| } | |||||
| void on_opr_destructed(OperatorBase* opr); | |||||
| /** | |||||
| * \brief Create operator of Opr type. | |||||
| */ | |||||
| template <typename Opr> | |||||
| std::unique_ptr<Opr> create_operator(); | |||||
| /* | |||||
| * ============================================================= | |||||
| * Users should call functions below to query memory requirement. | |||||
| * ============================================================= | |||||
| */ | |||||
| /** | |||||
| * \brief The internal data pointer of TensorND should be aligned to | |||||
| * alignment_requirement() in bytes. | |||||
| */ | |||||
| virtual size_t alignment_requirement() const; | |||||
| //! get alignment in bytes for rows of image 2D tensor format | |||||
| virtual size_t image2d_pitch_alignment() const; | |||||
| //! get vendor type | |||||
| virtual HandleVendorType vendor_type() const; | |||||
| HandleType type() const { | |||||
| return m_handle_type; | |||||
| } | |||||
| /** | |||||
| * \brief Check is the layout satisfy cross device copy constraint. | |||||
| * 1. The handle of the src and the dst is the same kind | |||||
| * 2. The dst is continguous. | |||||
| */ | |||||
| virtual bool check_cross_dev_copy_constraint(const TensorLayout &src); | |||||
| private: | |||||
| static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u; | |||||
| volatile uint32_t m_alive_magic = ALIVE_MAGIC; | |||||
| megcoreComputingHandle_t m_computing_handle; | |||||
| const HandleType m_handle_type; | |||||
| thin_function<void()> m_destructor; | |||||
| thin_function<void(OperatorBase*)> m_on_opr_destructed; | |||||
| Handle() = delete; | |||||
| Handle(const Handle &rhs) = delete; | |||||
| Handle &operator=(const Handle &rhs) = delete; | |||||
| virtual ~Handle(); | |||||
| /*! | |||||
| * \brief Get the underlying megcore computing handle. | |||||
| */ | |||||
| megcoreComputingHandle_t megcore_computing_handle() const { | |||||
| return m_computing_handle; | |||||
| } | |||||
| /*! | |||||
| * \brief set a callback function to be invoked when this handle is | |||||
| * destructed, so associated resources can be released (e.g. | |||||
| * computing handle) | |||||
| * | |||||
| * This function can be called at most once. | |||||
| */ | |||||
| void set_destructor(const thin_function<void()>& d); | |||||
| /*! | |||||
| * \brief set a callback to be invoked when an operator is destructed | |||||
| * \param[in,out] cb the callback function; it would be set to the | |||||
| * previous callback function | |||||
| */ | |||||
| void set_opr_destruct_callback(thin_function<void(OperatorBase*)>& cb) { | |||||
| cb.swap(m_on_opr_destructed); | |||||
| } | |||||
| void on_opr_destructed(OperatorBase* opr); | |||||
| /** | |||||
| * \brief Create operator of Opr type. | |||||
| */ | |||||
| template <typename Opr> | |||||
| std::unique_ptr<Opr> create_operator(); | |||||
| /* | |||||
| * ============================================================= | |||||
| * Users should call functions below to query memory requirement. | |||||
| * ============================================================= | |||||
| */ | |||||
| /** | |||||
| * \brief The internal data pointer of TensorND should be aligned to | |||||
| * alignment_requirement() in bytes. | |||||
| */ | |||||
| virtual size_t alignment_requirement() const; | |||||
| //! get alignment in bytes for rows of image 2D tensor format | |||||
| virtual size_t image2d_pitch_alignment() const; | |||||
| //! get vendor type | |||||
| virtual HandleVendorType vendor_type() const; | |||||
| HandleType type() const { return m_handle_type; } | |||||
| /** | |||||
| * \brief Check is the layout satisfy cross device copy constraint. | |||||
| * 1. The handle of the src and the dst is the same kind | |||||
| * 2. The dst is continguous. | |||||
| */ | |||||
| virtual bool check_cross_dev_copy_constraint(const TensorLayout& src); | |||||
| private: | |||||
| static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u; | |||||
| volatile uint32_t m_alive_magic = ALIVE_MAGIC; | |||||
| megcoreComputingHandle_t m_computing_handle; | |||||
| const HandleType m_handle_type; | |||||
| thin_function<void()> m_destructor; | |||||
| thin_function<void(OperatorBase*)> m_on_opr_destructed; | |||||
| Handle() = delete; | |||||
| Handle(const Handle& rhs) = delete; | |||||
| Handle& operator=(const Handle& rhs) = delete; | |||||
| }; | }; | ||||
| } // namespace megdnn | |||||
| } // namespace megdnn | |||||
| #include "megdnn/internal/visibility_epilogue.h" | #include "megdnn/internal/visibility_epilogue.h" | ||||
| @@ -49,8 +49,9 @@ public: | |||||
| mutable std::string m_input; | mutable std::string m_input; | ||||
| public: | public: | ||||
| Key(Handle* opr_handle, Algorithm::OprType opr_type, const TensorLayout* inp_layouts_ptr, | |||||
| size_t inp_layouts_size, const void* param_ptr = nullptr, size_t param_size = 0) | |||||
| Key(Handle* opr_handle, Algorithm::OprType opr_type, | |||||
| const TensorLayout* inp_layouts_ptr, size_t inp_layouts_size, | |||||
| const void* param_ptr = nullptr, size_t param_size = 0) | |||||
| : m_handle{opr_handle}, | : m_handle{opr_handle}, | ||||
| m_opr_type{static_cast<uint32_t>(opr_type)}, | m_opr_type{static_cast<uint32_t>(opr_type)}, | ||||
| m_inp_layouts_ptr{inp_layouts_ptr}, | m_inp_layouts_ptr{inp_layouts_ptr}, | ||||
| @@ -16,20 +16,19 @@ | |||||
| * \brief iterate through small (usually used) ndim values | * \brief iterate through small (usually used) ndim values | ||||
| */ | */ | ||||
| #define MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ...) \ | #define MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ...) \ | ||||
| cb(1 ,##__VA_ARGS__) cb(2 ,##__VA_ARGS__) cb(3 ,##__VA_ARGS__) | |||||
| cb(1, ##__VA_ARGS__) cb(2, ##__VA_ARGS__) cb(3, ##__VA_ARGS__) | |||||
| /*! | /*! | ||||
| * \brief iterate through large (rarely used) ndim values | * \brief iterate through large (rarely used) ndim values | ||||
| */ | */ | ||||
| #define MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ...) \ | #define MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ...) \ | ||||
| cb(4 ,##__VA_ARGS__) cb(5 ,##__VA_ARGS__) cb(6 ,##__VA_ARGS__) \ | |||||
| cb(7, ##__VA_ARGS__) | |||||
| cb(4, ##__VA_ARGS__) cb(5, ##__VA_ARGS__) cb(6, ##__VA_ARGS__) cb(7, ##__VA_ARGS__) | |||||
| /*! | /*! | ||||
| * \brief iterate through all ndim values | * \brief iterate through all ndim values | ||||
| */ | */ | ||||
| #define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) \ | |||||
| MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb ,##__VA_ARGS__) \ | |||||
| MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb ,##__VA_ARGS__) | |||||
| #define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) \ | |||||
| MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ##__VA_ARGS__) \ | |||||
| MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ##__VA_ARGS__) | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -11,14 +11,14 @@ | |||||
| // intentional no header guard here | // intentional no header guard here | ||||
| #include "megdnn/handle.h" | #include "megdnn/handle.h" | ||||
| #include "megdnn/oprs/base.h" | |||||
| #include "megdnn/opr_param_defs.h" | #include "megdnn/opr_param_defs.h" | ||||
| #include "megdnn/opr_result_defs.h" | #include "megdnn/opr_result_defs.h" | ||||
| #include "megdnn/oprs/base.h" | |||||
| #include "./visibility_prologue.h" | #include "./visibility_prologue.h" | ||||
| #include <limits> | |||||
| #include <array> | #include <array> | ||||
| #include <limits> | |||||
| #ifndef _megdnn_in | #ifndef _megdnn_in | ||||
| #define _megdnn_in | #define _megdnn_in | ||||
| @@ -29,36 +29,37 @@ | |||||
| #endif | #endif | ||||
| #ifndef _megdnn_tensor_in | #ifndef _megdnn_tensor_in | ||||
| #define _megdnn_tensor_in const TensorND & | |||||
| #define _megdnn_tensor_in const TensorND& | |||||
| #endif | #endif | ||||
| #ifndef _megdnn_tensor_out | #ifndef _megdnn_tensor_out | ||||
| #define _megdnn_tensor_out const TensorND & | |||||
| #define _megdnn_tensor_out const TensorND& | |||||
| #endif | #endif | ||||
| #ifndef _megdnn_tensor_inout | #ifndef _megdnn_tensor_inout | ||||
| #define _megdnn_tensor_inout const TensorND & | |||||
| #define _megdnn_tensor_inout const TensorND& | |||||
| #endif | #endif | ||||
| #ifndef _megdnn_workspace | #ifndef _megdnn_workspace | ||||
| #define _megdnn_workspace const Workspace & | |||||
| #define _megdnn_workspace const Workspace& | |||||
| #endif | #endif | ||||
| #define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||||
| public: \ | |||||
| _opr_name(Handle *handle): _base_name(handle) {} \ | |||||
| #define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||||
| public: \ | |||||
| _opr_name(Handle* handle) : _base_name(handle) {} | |||||
| #define DEF_OPR_IMPL(_opr_name, _base_name, _nr_inputs, _nr_outputs) \ | #define DEF_OPR_IMPL(_opr_name, _base_name, _nr_inputs, _nr_outputs) \ | ||||
| DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||||
| static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \ | |||||
| static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs; \ | |||||
| DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||||
| static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \ | |||||
| static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs; | |||||
| #define DEF_OPR_PARAM(_pname) \ | |||||
| public: \ | |||||
| using Param = param::_pname; \ | |||||
| Param& param() { return m_param; } \ | |||||
| const Param& param() const { return m_param; } \ | |||||
| protected: \ | |||||
| Param m_param | |||||
| #define DEF_OPR_PARAM(_pname) \ | |||||
| public: \ | |||||
| using Param = param::_pname; \ | |||||
| Param& param() { return m_param; } \ | |||||
| const Param& param() const { return m_param; } \ | |||||
| \ | |||||
| protected: \ | |||||
| Param m_param | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -20,4 +20,3 @@ | |||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -16,25 +16,21 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace opr_result { | namespace opr_result { | ||||
| struct Checksum { | |||||
| uint32_t checksum; | |||||
| union { | |||||
| int32_t iv; | |||||
| float fv; | |||||
| } last_val; | |||||
| bool operator == (const Checksum &rhs) const { | |||||
| return checksum == rhs.checksum && | |||||
| last_val.iv == rhs.last_val.iv; | |||||
| } | |||||
| bool operator != (const Checksum &rhs) const { | |||||
| return !operator==(rhs); | |||||
| } | |||||
| }; | |||||
| } // namespace opr_result | |||||
| } // namespace megdnn | |||||
| struct Checksum { | |||||
| uint32_t checksum; | |||||
| union { | |||||
| int32_t iv; | |||||
| float fv; | |||||
| } last_val; | |||||
| bool operator==(const Checksum& rhs) const { | |||||
| return checksum == rhs.checksum && last_val.iv == rhs.last_val.iv; | |||||
| } | |||||
| bool operator!=(const Checksum& rhs) const { return !operator==(rhs); } | |||||
| }; | |||||
| } // namespace opr_result | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -12,11 +12,11 @@ | |||||
| #include "megdnn/oprs/cv.h" | #include "megdnn/oprs/cv.h" | ||||
| #include "megdnn/oprs/general.h" | #include "megdnn/oprs/general.h" | ||||
| #include "megdnn/oprs/imgproc.h" | |||||
| #include "megdnn/oprs/linalg.h" | |||||
| #include "megdnn/oprs/nn.h" | #include "megdnn/oprs/nn.h" | ||||
| #include "megdnn/oprs/nn_int.h" | #include "megdnn/oprs/nn_int.h" | ||||
| #include "megdnn/oprs/imgproc.h" | |||||
| #include "megdnn/oprs/utils.h" | #include "megdnn/oprs/utils.h" | ||||
| #include "megdnn/oprs/linalg.h" | |||||
| template <typename Opr> | template <typename Opr> | ||||
| struct OprArityTrait; | struct OprArityTrait; | ||||
| @@ -53,6 +53,4 @@ INST_ARITY(megdnn::PoolingBackward, 3, 1); | |||||
| #undef INST_ARITY | #undef INST_ARITY | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -90,7 +90,7 @@ enum class AlgoDataType : uint32_t { | |||||
| INT8X8X16 = 1 << 4, | INT8X8X16 = 1 << 4, | ||||
| INT16X16X32 = 1 << 5, | INT16X16X32 = 1 << 5, | ||||
| INT4X4X16 = 1 << 6, | INT4X4X16 = 1 << 6, | ||||
| QINT4x4x32 = 1 << 7, | |||||
| QINT4x4x32 = 1 << 7, | |||||
| }; | }; | ||||
| /*! | /*! | ||||
| @@ -195,16 +195,16 @@ public: | |||||
| Handle::HandleType handle_type() const { return m_handle_type; } | Handle::HandleType handle_type() const { return m_handle_type; } | ||||
| Info::Desc desc() const { return {handle_type(), type(), param(), name()}; } | Info::Desc desc() const { return {handle_type(), type(), param(), name()}; } | ||||
| Info info() const { | |||||
| return {desc(), attribute()}; | |||||
| } | |||||
| Info info() const { return {desc(), attribute()}; } | |||||
| template <typename T> | template <typename T> | ||||
| static void serialize_write_pod(const T& val, std::string& result) { | static void serialize_write_pod(const T& val, std::string& result) { | ||||
| static_assert(std::is_trivially_copyable<T>::value, | |||||
| "type should be trivially copyable"); | |||||
| static_assert(!std::is_pointer<T>::value, | |||||
| "serialize pointer is unsafe in eager execution mode"); | |||||
| static_assert( | |||||
| std::is_trivially_copyable<T>::value, | |||||
| "type should be trivially copyable"); | |||||
| static_assert( | |||||
| !std::is_pointer<T>::value, | |||||
| "serialize pointer is unsafe in eager execution mode"); | |||||
| result.append(reinterpret_cast<const char*>(&val), sizeof(T)); | result.append(reinterpret_cast<const char*>(&val), sizeof(T)); | ||||
| } | } | ||||
| @@ -231,9 +231,8 @@ public: | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| static std::string deserialize_read_pod(const std::string& data, | |||||
| size_t offset = 0, | |||||
| size_t size = 0) { | |||||
| static std::string deserialize_read_pod( | |||||
| const std::string& data, size_t offset = 0, size_t size = 0) { | |||||
| return std::string(data.data() + offset, size); | return std::string(data.data() + offset, size); | ||||
| } | } | ||||
| @@ -286,8 +285,8 @@ public: | |||||
| * \param layouts origin layouts of the parent opr | * \param layouts origin layouts of the parent opr | ||||
| * \param opr parent opr | * \param opr parent opr | ||||
| */ | */ | ||||
| virtual std::vector<SearchItem> get_subopr_list(const TensorLayoutArray&, | |||||
| const OperatorBase*) const { | |||||
| virtual std::vector<SearchItem> get_subopr_list( | |||||
| const TensorLayoutArray&, const OperatorBase*) const { | |||||
| return {}; | return {}; | ||||
| } | } | ||||
| @@ -333,9 +332,7 @@ public: | |||||
| ExecutionPolicy& execution_policy() { return m_execution_policy; } | ExecutionPolicy& execution_policy() { return m_execution_policy; } | ||||
| const ExecutionPolicy& execution_policy() const { | |||||
| return m_execution_policy; | |||||
| } | |||||
| const ExecutionPolicy& execution_policy() const { return m_execution_policy; } | |||||
| virtual Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) = 0; | virtual Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) = 0; | ||||
| @@ -355,8 +352,8 @@ public: | |||||
| using AlgoAttribute = detail::Algorithm::Attribute; | using AlgoAttribute = detail::Algorithm::Attribute; | ||||
| //! get all possible algorithm decriptions for the specified layouts | //! get all possible algorithm decriptions for the specified layouts | ||||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||||
| const TensorLayout& p1) { | |||||
| std::vector<AlgorithmInfo> get_all_algorithms_info( | |||||
| const TensorLayout& p0, const TensorLayout& p1) { | |||||
| std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
| for (auto&& algo : get_all_algorithms(p0, p1)) { | for (auto&& algo : get_all_algorithms(p0, p1)) { | ||||
| ret.emplace_back(algo->info()); | ret.emplace_back(algo->info()); | ||||
| @@ -364,8 +361,8 @@ public: | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||||
| const TensorLayout& p1) { | |||||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||||
| const TensorLayout& p0, const TensorLayout& p1) { | |||||
| std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
| for (auto&& algo : get_all_algorithms_safe(p0, p1)) { | for (auto&& algo : get_all_algorithms_safe(p0, p1)) { | ||||
| ret.emplace_back(algo->info()); | ret.emplace_back(algo->info()); | ||||
| @@ -382,12 +379,11 @@ public: | |||||
| */ | */ | ||||
| AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | const TensorLayout& p0, const TensorLayout& p1, | ||||
| size_t workspace_limit_in_bytes = | |||||
| std::numeric_limits<size_t>::max(), | |||||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | ||||
| return get_algorithm_heuristic(p0, p1, workspace_limit_in_bytes, | |||||
| positive_attr, negative_attr) | |||||
| return get_algorithm_heuristic( | |||||
| p0, p1, workspace_limit_in_bytes, positive_attr, negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| @@ -408,8 +404,7 @@ protected: | |||||
| */ | */ | ||||
| virtual Algorithm* get_algorithm_heuristic( | virtual Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | const TensorLayout& p0, const TensorLayout& p1, | ||||
| size_t workspace_limit_in_bytes = | |||||
| std::numeric_limits<size_t>::max(), | |||||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | ||||
| }; | }; | ||||
| @@ -423,9 +418,8 @@ public: | |||||
| using AlgoAttribute = detail::Algorithm::Attribute; | using AlgoAttribute = detail::Algorithm::Attribute; | ||||
| //! get all possible algorithm decriptions for the specified layouts | //! get all possible algorithm decriptions for the specified layouts | ||||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||||
| const TensorLayout& p1, | |||||
| const TensorLayout& p2) { | |||||
| std::vector<AlgorithmInfo> get_all_algorithms_info( | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) { | |||||
| std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
| for (auto&& algo : get_all_algorithms(p0, p1, p2)) { | for (auto&& algo : get_all_algorithms(p0, p1, p2)) { | ||||
| ret.emplace_back(algo->info()); | ret.emplace_back(algo->info()); | ||||
| @@ -433,9 +427,8 @@ public: | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||||
| const TensorLayout& p1, | |||||
| const TensorLayout& p2) { | |||||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) { | |||||
| std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
| for (auto&& algo : get_all_algorithms_safe(p0, p1, p2)) { | for (auto&& algo : get_all_algorithms_safe(p0, p1, p2)) { | ||||
| ret.emplace_back(algo->info()); | ret.emplace_back(algo->info()); | ||||
| @@ -451,14 +444,13 @@ public: | |||||
| * \p workspace_limit_in_bytes. | * \p workspace_limit_in_bytes. | ||||
| */ | */ | ||||
| AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, | |||||
| size_t workspace_limit_in_bytes = | |||||
| std::numeric_limits<size_t>::max(), | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | ||||
| return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes, | |||||
| positive_attr, negative_attr) | |||||
| return get_algorithm_heuristic( | |||||
| p0, p1, p2, workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| @@ -467,11 +459,9 @@ protected: | |||||
| //! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
| virtual std::vector<Algorithm*> get_all_algorithms( | virtual std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2) = 0; | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) = 0; | |||||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | virtual std::vector<Algorithm*> get_all_algorithms_safe( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2) = 0; | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) = 0; | |||||
| /** | /** | ||||
| * \brief Returns the best algorithm by heuristic. | * \brief Returns the best algorithm by heuristic. | ||||
| @@ -480,10 +470,8 @@ protected: | |||||
| * \p workspace_limit_in_bytes. | * \p workspace_limit_in_bytes. | ||||
| */ | */ | ||||
| virtual Algorithm* get_algorithm_heuristic( | virtual Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, | |||||
| size_t workspace_limit_in_bytes = | |||||
| std::numeric_limits<size_t>::max(), | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | ||||
| }; | }; | ||||
| @@ -497,10 +485,9 @@ public: | |||||
| using AlgoAttribute = detail::Algorithm::Attribute; | using AlgoAttribute = detail::Algorithm::Attribute; | ||||
| //! get all possible algorithm decriptions for the specified layouts | //! get all possible algorithm decriptions for the specified layouts | ||||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||||
| const TensorLayout& p1, | |||||
| const TensorLayout& p2, | |||||
| const TensorLayout& p3) { | |||||
| std::vector<AlgorithmInfo> get_all_algorithms_info( | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3) { | |||||
| std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
| for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) { | for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) { | ||||
| ret.emplace_back(algo->info()); | ret.emplace_back(algo->info()); | ||||
| @@ -508,10 +495,9 @@ public: | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||||
| const TensorLayout& p1, | |||||
| const TensorLayout& p2, | |||||
| const TensorLayout& p3) { | |||||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3) { | |||||
| std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
| for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3)) { | for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3)) { | ||||
| ret.emplace_back(algo->info()); | ret.emplace_back(algo->info()); | ||||
| @@ -527,14 +513,14 @@ public: | |||||
| * \p workspace_limit_in_bytes. | * \p workspace_limit_in_bytes. | ||||
| */ | */ | ||||
| AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| size_t workspace_limit_in_bytes = | |||||
| std::numeric_limits<size_t>::max(), | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3, | |||||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | ||||
| return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes, | |||||
| positive_attr, negative_attr) | |||||
| return get_algorithm_heuristic( | |||||
| p0, p1, p2, p3, workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| @@ -543,11 +529,11 @@ protected: | |||||
| //! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
| virtual std::vector<Algorithm*> get_all_algorithms( | virtual std::vector<Algorithm*> get_all_algorithms( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3) = 0; | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3) = 0; | |||||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | virtual std::vector<Algorithm*> get_all_algorithms_safe( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3) = 0; | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3) = 0; | |||||
| /** | /** | ||||
| * \brief Returns the best algorithm by heuristic. | * \brief Returns the best algorithm by heuristic. | ||||
| @@ -556,10 +542,9 @@ protected: | |||||
| * \p workspace_limit_in_bytes. | * \p workspace_limit_in_bytes. | ||||
| */ | */ | ||||
| virtual Algorithm* get_algorithm_heuristic( | virtual Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| size_t workspace_limit_in_bytes = | |||||
| std::numeric_limits<size_t>::max(), | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3, | |||||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | ||||
| }; | }; | ||||
| @@ -573,11 +558,9 @@ public: | |||||
| using AlgoAttribute = detail::Algorithm::Attribute; | using AlgoAttribute = detail::Algorithm::Attribute; | ||||
| //! get all possible algorithm decriptions for the specified layouts | //! get all possible algorithm decriptions for the specified layouts | ||||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||||
| const TensorLayout& p1, | |||||
| const TensorLayout& p2, | |||||
| const TensorLayout& p3, | |||||
| const TensorLayout& p4) { | |||||
| std::vector<AlgorithmInfo> get_all_algorithms_info( | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3, const TensorLayout& p4) { | |||||
| std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
| for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) { | for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) { | ||||
| ret.emplace_back(algo->info()); | ret.emplace_back(algo->info()); | ||||
| @@ -585,11 +568,9 @@ public: | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||||
| const TensorLayout& p1, | |||||
| const TensorLayout& p2, | |||||
| const TensorLayout& p3, | |||||
| const TensorLayout& p4) { | |||||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3, const TensorLayout& p4) { | |||||
| std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
| for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4)) { | for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4)) { | ||||
| ret.emplace_back(algo->info()); | ret.emplace_back(algo->info()); | ||||
| @@ -605,16 +586,14 @@ public: | |||||
| * \p workspace_limit_in_bytes. | * \p workspace_limit_in_bytes. | ||||
| */ | */ | ||||
| AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| const TensorLayout& p4, | |||||
| size_t workspace_limit_in_bytes = | |||||
| std::numeric_limits<size_t>::max(), | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3, const TensorLayout& p4, | |||||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | ||||
| return get_algorithm_heuristic(p0, p1, p2, p3, p4, | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| return get_algorithm_heuristic( | |||||
| p0, p1, p2, p3, p4, workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| @@ -622,14 +601,12 @@ protected: | |||||
| ~MultiAlgoOpr() = default; | ~MultiAlgoOpr() = default; | ||||
| //! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| const TensorLayout& p4) = 0; | |||||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3, const TensorLayout& p4) = 0; | |||||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | virtual std::vector<Algorithm*> get_all_algorithms_safe( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| const TensorLayout& p4) = 0; | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3, const TensorLayout& p4) = 0; | |||||
| /** | /** | ||||
| * \brief Returns the best algorithm by heuristic. | * \brief Returns the best algorithm by heuristic. | ||||
| @@ -638,11 +615,9 @@ protected: | |||||
| * \p workspace_limit_in_bytes. | * \p workspace_limit_in_bytes. | ||||
| */ | */ | ||||
| virtual Algorithm* get_algorithm_heuristic( | virtual Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| const TensorLayout& p4, | |||||
| size_t workspace_limit_in_bytes = | |||||
| std::numeric_limits<size_t>::max(), | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3, const TensorLayout& p4, | |||||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | ||||
| }; | }; | ||||
| @@ -657,9 +632,8 @@ public: | |||||
| //! get all possible algorithm decriptions for the specified layouts | //! get all possible algorithm decriptions for the specified layouts | ||||
| std::vector<AlgorithmInfo> get_all_algorithms_info( | std::vector<AlgorithmInfo> get_all_algorithms_info( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| const TensorLayout& p4, const TensorLayout& p5, | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||||
| const TensorLayout& p6, const TensorLayout& p7) { | const TensorLayout& p6, const TensorLayout& p7) { | ||||
| std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
| for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) { | for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) { | ||||
| @@ -669,9 +643,8 @@ public: | |||||
| } | } | ||||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| const TensorLayout& p4, const TensorLayout& p5, | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||||
| const TensorLayout& p6, const TensorLayout& p7) { | const TensorLayout& p6, const TensorLayout& p7) { | ||||
| std::vector<AlgorithmInfo> ret; | std::vector<AlgorithmInfo> ret; | ||||
| for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4, p5, p6, p7)) { | for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4, p5, p6, p7)) { | ||||
| @@ -687,17 +660,15 @@ public: | |||||
| * The selected algorithm should not use workspace more than | * The selected algorithm should not use workspace more than | ||||
| */ | */ | ||||
| AlgorithmInfo get_algorithm_info_heuristic( | AlgorithmInfo get_algorithm_info_heuristic( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| const TensorLayout& p4, const TensorLayout& p5, | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||||
| const TensorLayout& p6, const TensorLayout& p7, | const TensorLayout& p6, const TensorLayout& p7, | ||||
| size_t workspace_limit_in_bytes = | |||||
| std::numeric_limits<size_t>::max(), | |||||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | ||||
| return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7, | |||||
| workspace_limit_in_bytes, positive_attr, | |||||
| negative_attr) | |||||
| return get_algorithm_heuristic( | |||||
| p0, p1, p2, p3, p4, p5, p6, p7, workspace_limit_in_bytes, | |||||
| positive_attr, negative_attr) | |||||
| ->info(); | ->info(); | ||||
| } | } | ||||
| @@ -705,15 +676,13 @@ protected: | |||||
| ~MultiAlgoOpr() = default; | ~MultiAlgoOpr() = default; | ||||
| //! get all possible algorithms for the specified layouts | //! get all possible algorithms for the specified layouts | ||||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| const TensorLayout& p4, const TensorLayout& p5, | |||||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||||
| const TensorLayout& p6, const TensorLayout& p7) = 0; | const TensorLayout& p6, const TensorLayout& p7) = 0; | ||||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | virtual std::vector<Algorithm*> get_all_algorithms_safe( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| const TensorLayout& p4, const TensorLayout& p5, | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||||
| const TensorLayout& p6, const TensorLayout& p7) = 0; | const TensorLayout& p6, const TensorLayout& p7) = 0; | ||||
| /** | /** | ||||
| @@ -723,12 +692,10 @@ protected: | |||||
| * \p workspace_limit_in_bytes. | * \p workspace_limit_in_bytes. | ||||
| */ | */ | ||||
| virtual Algorithm* get_algorithm_heuristic( | virtual Algorithm* get_algorithm_heuristic( | ||||
| const TensorLayout& p0, const TensorLayout& p1, | |||||
| const TensorLayout& p2, const TensorLayout& p3, | |||||
| const TensorLayout& p4, const TensorLayout& p5, | |||||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||||
| const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||||
| const TensorLayout& p6, const TensorLayout& p7, | const TensorLayout& p6, const TensorLayout& p7, | ||||
| size_t workspace_limit_in_bytes = | |||||
| std::numeric_limits<size_t>::max(), | |||||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | ||||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | ||||
| }; | }; | ||||
| @@ -31,15 +31,17 @@ class FlipForward : public FlipBase { | |||||
| DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1); | DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1); | ||||
| public: | public: | ||||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst); | void deduce_layout(const TensorLayout& src, TensorLayout& dst); | ||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& dst) = 0; | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& dst) = 0; | |||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| using Flip = FlipForward; | using Flip = FlipForward; | ||||
| @@ -56,15 +58,17 @@ class RotateForward : public RotateBase { | |||||
| DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1); | DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1); | ||||
| public: | public: | ||||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst); | void deduce_layout(const TensorLayout& src, TensorLayout& dst); | ||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& dst) = 0; | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& dst) = 0; | |||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| using Rotate = RotateForward; | using Rotate = RotateForward; | ||||
| @@ -81,15 +85,17 @@ class ROICopyForward : public ROICopyBase { | |||||
| DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1); | DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1); | ||||
| public: | public: | ||||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst); | void deduce_layout(const TensorLayout& src, TensorLayout& dst); | ||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& dst) = 0; | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& dst) = 0; | |||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| using ROICopy = ROICopyForward; | using ROICopy = ROICopyForward; | ||||
| @@ -106,15 +112,17 @@ class CvtColorForward : public CvtColorBase { | |||||
| DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1); | DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1); | ||||
| public: | public: | ||||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst); | void deduce_layout(const TensorLayout& src, TensorLayout& dst); | ||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& dst) = 0; | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& dst) = 0; | |||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| using CvtColor = CvtColorForward; | using CvtColor = CvtColorForward; | ||||
| @@ -130,8 +138,9 @@ public: | |||||
| using BorderMode = Param::BorderMode; | using BorderMode = Param::BorderMode; | ||||
| protected: | protected: | ||||
| void check_layout_fwd(const TensorLayout& src, const TensorLayout& trans, | |||||
| const TensorLayout& dst); | |||||
| void check_layout_fwd( | |||||
| const TensorLayout& src, const TensorLayout& trans, | |||||
| const TensorLayout& dst); | |||||
| std::string param_msg() const; | std::string param_msg() const; | ||||
| int get_real_coord(int p, int len); | int get_real_coord(int p, int len); | ||||
| }; | }; | ||||
| @@ -148,15 +157,17 @@ public: | |||||
| * \warning src, trans, border_value, dst should be contiguous | * \warning src, trans, border_value, dst should be contiguous | ||||
| * The size of trans is N * 2 * 3 | * The size of trans is N * 2 * 3 | ||||
| */ | */ | ||||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, | |||||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& trans, | |||||
| const TensorLayout& dst) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& trans, | |||||
| const TensorLayout& dst) = 0; | |||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& src, const TensorLayout& trans, | |||||
| const TensorLayout& dst, size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& src, const TensorLayout& trans, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| using WarpAffine = WarpAffineForward; | using WarpAffine = WarpAffineForward; | ||||
| @@ -173,15 +184,17 @@ class GaussianBlurForward : public GaussianBlurBase { | |||||
| DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1); | DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1); | ||||
| public: | public: | ||||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst); | void deduce_layout(const TensorLayout& src, TensorLayout& dst); | ||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& dst) = 0; | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& dst) = 0; | |||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| using GaussianBlur = GaussianBlurForward; | using GaussianBlur = GaussianBlurForward; | ||||
| @@ -212,15 +225,17 @@ class ResizeForward : public ResizeBase { | |||||
| DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1); | DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1); | ||||
| public: | public: | ||||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& dst) = 0; | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& dst) = 0; | |||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| using Resize = ResizeForward; | using Resize = ResizeForward; | ||||
| @@ -228,15 +243,17 @@ class ResizeBackward : public ResizeBase { | |||||
| DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1); | DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1); | ||||
| public: | public: | ||||
| virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||||
| const TensorLayout& mat) = 0; | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& diff, const TensorLayout& mat) = 0; | |||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& diff, const TensorLayout& mat, | |||||
| size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& diff, const TensorLayout& mat, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| /** | /** | ||||
| @@ -251,29 +268,32 @@ public: | |||||
| using BorderMode = Param::BorderMode; | using BorderMode = Param::BorderMode; | ||||
| protected: | protected: | ||||
| void check_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy, | |||||
| const TensorLayout& dst); | |||||
| void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy, | |||||
| TensorLayout& dst); | |||||
| void check_layout_fwd( | |||||
| const TensorLayout& src, const TensorLayout& map_xy, | |||||
| const TensorLayout& dst); | |||||
| void deduce_layout_fwd( | |||||
| const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst); | |||||
| }; | }; | ||||
| class RemapForward : public RemapBase { | class RemapForward : public RemapBase { | ||||
| DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1); | DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1); | ||||
| public: | public: | ||||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout(const TensorLayout& src, const TensorLayout& map_xy, | |||||
| TensorLayout& dst); | |||||
| void deduce_layout( | |||||
| const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst); | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& map_xy, | |||||
| const TensorLayout& dst) = 0; | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& map_xy, | |||||
| const TensorLayout& dst) = 0; | |||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& src, const TensorLayout& map_xy, | |||||
| const TensorLayout& dst, size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& src, const TensorLayout& map_xy, | |||||
| const TensorLayout& dst, size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| using Remap = RemapForward; | using Remap = RemapForward; | ||||
| @@ -281,35 +301,37 @@ class RemapBackwardData : public RemapBase { | |||||
| DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1); | DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1); | ||||
| public: | public: | ||||
| virtual void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in map_xy, _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& map_xy, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) = 0; | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& map_xy, const TensorLayout& diff, | |||||
| const TensorLayout& grad) = 0; | |||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& map_xy, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& map_xy, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| class RemapBackwardMat : public RemapBase { | class RemapBackwardMat : public RemapBase { | ||||
| DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1); | DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1); | ||||
| public: | public: | ||||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& map_xy, | |||||
| const TensorLayout& diff, | |||||
| const TensorLayout& grad) = 0; | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& map_xy, | |||||
| const TensorLayout& diff, const TensorLayout& grad) = 0; | |||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& src, const TensorLayout& map_xy, | |||||
| const TensorLayout& diff, const TensorLayout& grad, | |||||
| size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& src, const TensorLayout& map_xy, | |||||
| const TensorLayout& diff, const TensorLayout& grad, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| class SeparableFilterBase : public OperatorBase { | class SeparableFilterBase : public OperatorBase { | ||||
| @@ -317,32 +339,34 @@ class SeparableFilterBase : public OperatorBase { | |||||
| DEF_OPR_PARAM(SeparableFilter); | DEF_OPR_PARAM(SeparableFilter); | ||||
| protected: | protected: | ||||
| void deduce_layout_fwd(const TensorLayout& src, | |||||
| const TensorLayout& filter_x, | |||||
| const TensorLayout& filter_y, TensorLayout& dst); | |||||
| void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x, | |||||
| const TensorLayout& filter_y, | |||||
| const TensorLayout& dst); | |||||
| void deduce_layout_fwd( | |||||
| const TensorLayout& src, const TensorLayout& filter_x, | |||||
| const TensorLayout& filter_y, TensorLayout& dst); | |||||
| void check_layout_fwd( | |||||
| const TensorLayout& src, const TensorLayout& filter_x, | |||||
| const TensorLayout& filter_y, const TensorLayout& dst); | |||||
| }; | }; | ||||
| class SeparableFilterForward : public SeparableFilterBase { | class SeparableFilterForward : public SeparableFilterBase { | ||||
| DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1); | DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1); | ||||
| public: | public: | ||||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x, | |||||
| _megdnn_tensor_in filter_y, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x, | |||||
| const TensorLayout& filter_y, TensorLayout& dst); | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& filter_x, | |||||
| const TensorLayout& filter_y, | |||||
| const TensorLayout& dst) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_in filter_x, | |||||
| _megdnn_tensor_in filter_y, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout( | |||||
| const TensorLayout& src, const TensorLayout& filter_x, | |||||
| const TensorLayout& filter_y, TensorLayout& dst); | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& filter_x, | |||||
| const TensorLayout& filter_y, const TensorLayout& dst) = 0; | |||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& src, const TensorLayout& filter_x, | |||||
| const TensorLayout& filter_y, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& src, const TensorLayout& filter_x, | |||||
| const TensorLayout& filter_y, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| using SeparableFilter = SeparableFilterForward; | using SeparableFilter = SeparableFilterForward; | ||||
| @@ -13,173 +13,162 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| class WarpPerspectiveBase: public OperatorBase { | |||||
| class WarpPerspectiveBase : public OperatorBase { | |||||
| DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase); | DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase); | ||||
| DEF_OPR_PARAM(WarpPerspective); | DEF_OPR_PARAM(WarpPerspective); | ||||
| public: | |||||
| using InterpolationMode = Param::InterpolationMode; | |||||
| using BorderMode = Param::BorderMode; | |||||
| protected: | |||||
| void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat, | |||||
| const TensorLayout &dst) { | |||||
| check_layout_fwd(src, mat, {}, dst); | |||||
| } | |||||
| void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat, | |||||
| const TensorLayout &mat_idx, const TensorLayout &dst); | |||||
| std::string param_msg() const; | |||||
| int get_real_coord(int p, int len); | |||||
| public: | |||||
| using InterpolationMode = Param::InterpolationMode; | |||||
| using BorderMode = Param::BorderMode; | |||||
| protected: | |||||
| void check_layout_fwd( | |||||
| const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { | |||||
| check_layout_fwd(src, mat, {}, dst); | |||||
| } | |||||
| void check_layout_fwd( | |||||
| const TensorLayout& src, const TensorLayout& mat, | |||||
| const TensorLayout& mat_idx, const TensorLayout& dst); | |||||
| std::string param_msg() const; | |||||
| int get_real_coord(int p, int len); | |||||
| }; | }; | ||||
| class WarpPerspectiveForward: public WarpPerspectiveBase { | |||||
| class WarpPerspectiveForward : public WarpPerspectiveBase { | |||||
| DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1); | DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1); | ||||
| public: | |||||
| /** | |||||
| * \param[in] src (n, channel, in_height, in_width) | |||||
| * \param[in] mat (n, 3, 3) | |||||
| * \param[out] dst (n, channel, out_height, out_width) | |||||
| * | |||||
| * \see http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine | |||||
| * | |||||
| * denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2] | |||||
| * dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator, | |||||
| * (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator) | |||||
| * | |||||
| * src and dst can have different shapes, as long as their n and c agree. | |||||
| * src, mat and dst should be contiguous. | |||||
| */ | |||||
| void exec(_megdnn_tensor_in src, | |||||
| _megdnn_tensor_in mat, | |||||
| _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) { | |||||
| exec(src, mat, {}, dst, workspace); | |||||
| } | |||||
| /** | |||||
| * \p src should have batch size m, and \p mat and \p mat_idx should | |||||
| * both have batch size n. Each item in \p mat_idx must be in the range | |||||
| * of [0, m-1]. | |||||
| * | |||||
| * \param mat_idx the indices of input image that each matrix in \p mat | |||||
| * should act on. It can also be empty and in such case \p mat | |||||
| * should have the same batch size as \p src. | |||||
| */ | |||||
| virtual void exec(_megdnn_tensor_in src, | |||||
| _megdnn_tensor_in mat, | |||||
| _megdnn_tensor_in mat_idx, | |||||
| _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| size_t get_workspace_in_bytes(const TensorLayout &src, | |||||
| const TensorLayout &mat, | |||||
| const TensorLayout &dst) { | |||||
| return get_workspace_in_bytes(src, mat, {}, dst); | |||||
| } | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||||
| const TensorLayout &mat, | |||||
| const TensorLayout &mat_idx, | |||||
| const TensorLayout &dst) = 0; | |||||
| protected: | |||||
| void check_exec(const TensorLayout &src, | |||||
| const TensorLayout &mat, | |||||
| const TensorLayout &mat_idx, | |||||
| const TensorLayout &dst, | |||||
| size_t workspace_in_bytes); | |||||
| void check_exec_allow_nhwc_mat_idx(const TensorLayout &src, | |||||
| const TensorLayout &mat, | |||||
| const TensorLayout &mat_idx, | |||||
| const TensorLayout &dst, | |||||
| size_t workspace_in_bytes); | |||||
| public: | |||||
| /** | |||||
| * \param[in] src (n, channel, in_height, in_width) | |||||
| * \param[in] mat (n, 3, 3) | |||||
| * \param[out] dst (n, channel, out_height, out_width) | |||||
| * | |||||
| * \see | |||||
| * http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine | |||||
| * | |||||
| * denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2] | |||||
| * dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator, | |||||
| * (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator) | |||||
| * | |||||
| * src and dst can have different shapes, as long as their n and c agree. | |||||
| * src, mat and dst should be contiguous. | |||||
| */ | |||||
| void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) { | |||||
| exec(src, mat, {}, dst, workspace); | |||||
| } | |||||
| /** | |||||
| * \p src should have batch size m, and \p mat and \p mat_idx should | |||||
| * both have batch size n. Each item in \p mat_idx must be in the range | |||||
| * of [0, m-1]. | |||||
| * | |||||
| * \param mat_idx the indices of input image that each matrix in \p mat | |||||
| * should act on. It can also be empty and in such case \p mat | |||||
| * should have the same batch size as \p src. | |||||
| */ | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | |||||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||||
| size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { | |||||
| return get_workspace_in_bytes(src, mat, {}, dst); | |||||
| } | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& mat, | |||||
| const TensorLayout& mat_idx, const TensorLayout& dst) = 0; | |||||
| protected: | |||||
| void check_exec( | |||||
| const TensorLayout& src, const TensorLayout& mat, | |||||
| const TensorLayout& mat_idx, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| void check_exec_allow_nhwc_mat_idx( | |||||
| const TensorLayout& src, const TensorLayout& mat, | |||||
| const TensorLayout& mat_idx, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| using WarpPerspective = WarpPerspectiveForward; | using WarpPerspective = WarpPerspectiveForward; | ||||
| class WarpPerspectiveBackwardData: public WarpPerspectiveBase { | |||||
| class WarpPerspectiveBackwardData : public WarpPerspectiveBase { | |||||
| DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1); | DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1); | ||||
| public: | |||||
| /** | |||||
| * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||||
| * \param[in] diff the backpropagated gradient wrt. dst | |||||
| * \param[out] grad the backpropagated gradient wrt. src | |||||
| * \param[out] workspace temporary workspace to perform backward | |||||
| */ | |||||
| void exec(_megdnn_tensor_in mat, | |||||
| _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) { | |||||
| exec(mat, {}, diff, grad, workspace); | |||||
| } | |||||
| virtual void exec(_megdnn_tensor_in mat, | |||||
| _megdnn_tensor_in mat_idx, | |||||
| _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| size_t get_workspace_in_bytes(const TensorLayout &mat, | |||||
| const TensorLayout &diff, | |||||
| const TensorLayout &grad) { | |||||
| return get_workspace_in_bytes(mat, {}, diff, grad); | |||||
| } | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout &mat, | |||||
| const TensorLayout &mat_idx, | |||||
| const TensorLayout &diff, | |||||
| const TensorLayout &grad) = 0; | |||||
| protected: | |||||
| void check_exec(const TensorLayout &mat, | |||||
| const TensorLayout &mat_idx, | |||||
| const TensorLayout &diff, | |||||
| const TensorLayout &grad, | |||||
| size_t workspace_in_bytes); | |||||
| public: | |||||
| /** | |||||
| * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||||
| * \param[in] diff the backpropagated gradient wrt. dst | |||||
| * \param[out] grad the backpropagated gradient wrt. src | |||||
| * \param[out] workspace temporary workspace to perform backward | |||||
| */ | |||||
| void exec( | |||||
| _megdnn_tensor_in mat, _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) { | |||||
| exec(mat, {}, diff, grad, workspace); | |||||
| } | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||||
| size_t get_workspace_in_bytes( | |||||
| const TensorLayout& mat, const TensorLayout& diff, | |||||
| const TensorLayout& grad) { | |||||
| return get_workspace_in_bytes(mat, {}, diff, grad); | |||||
| } | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& mat, const TensorLayout& mat_idx, | |||||
| const TensorLayout& diff, const TensorLayout& grad) = 0; | |||||
| protected: | |||||
| void check_exec( | |||||
| const TensorLayout& mat, const TensorLayout& mat_idx, | |||||
| const TensorLayout& diff, const TensorLayout& grad, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| class WarpPerspectiveBackwardMat: public WarpPerspectiveBase { | |||||
| class WarpPerspectiveBackwardMat : public WarpPerspectiveBase { | |||||
| DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1); | DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1); | ||||
| public: | |||||
| /** | |||||
| * \param[in] src the `src' parameter in WarpPerspectiveForward::exec | |||||
| * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||||
| * \param[in] diff the backpropagated gradient wrt. dst | |||||
| * \param[out] grad the backpropagated gradient wrt. mat | |||||
| * \param[out] workspace temporary workspace to perform backward | |||||
| */ | |||||
| void exec(_megdnn_tensor_in src, | |||||
| _megdnn_tensor_in mat, | |||||
| _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) { | |||||
| exec(src, mat, {}, diff, grad, workspace); | |||||
| } | |||||
| virtual void exec(_megdnn_tensor_in src, | |||||
| _megdnn_tensor_in mat, | |||||
| _megdnn_tensor_in mat_idx, | |||||
| _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| size_t get_workspace_in_bytes(const TensorLayout &src, | |||||
| const TensorLayout &mat, | |||||
| const TensorLayout &diff, | |||||
| const TensorLayout &grad) { | |||||
| return get_workspace_in_bytes(src, mat, {}, diff, grad); | |||||
| } | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||||
| const TensorLayout &mat, | |||||
| const TensorLayout &mat_idx, | |||||
| const TensorLayout &diff, | |||||
| const TensorLayout &grad) = 0; | |||||
| protected: | |||||
| void check_exec(const TensorLayout &src, | |||||
| const TensorLayout &mat, | |||||
| const TensorLayout &mat_idx, | |||||
| const TensorLayout &diff, | |||||
| const TensorLayout &grad, | |||||
| size_t workspace_in_bytes); | |||||
| public: | |||||
| /** | |||||
| * \param[in] src the `src' parameter in WarpPerspectiveForward::exec | |||||
| * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||||
| * \param[in] diff the backpropagated gradient wrt. dst | |||||
| * \param[out] grad the backpropagated gradient wrt. mat | |||||
| * \param[out] workspace temporary workspace to perform backward | |||||
| */ | |||||
| void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in diff, | |||||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) { | |||||
| exec(src, mat, {}, diff, grad, workspace); | |||||
| } | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& mat, const TensorLayout& diff, | |||||
| const TensorLayout& grad) { | |||||
| return get_workspace_in_bytes(src, mat, {}, diff, grad); | |||||
| } | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& mat, | |||||
| const TensorLayout& mat_idx, const TensorLayout& diff, | |||||
| const TensorLayout& grad) = 0; | |||||
| protected: | |||||
| void check_exec( | |||||
| const TensorLayout& src, const TensorLayout& mat, | |||||
| const TensorLayout& mat_idx, const TensorLayout& diff, | |||||
| const TensorLayout& grad, size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| class DctChannelSelectForward : public OperatorBase { | class DctChannelSelectForward : public OperatorBase { | ||||
| @@ -194,37 +183,32 @@ public: | |||||
| * \param[dst] DctChannelSelectForward output, default fp32 nchw tensor | * \param[dst] DctChannelSelectForward output, default fp32 nchw tensor | ||||
| * \param[out] workspace temporary workspace to perform forward | * \param[out] workspace temporary workspace to perform forward | ||||
| */ | */ | ||||
| virtual void exec(_megdnn_tensor_in src, | |||||
| _megdnn_tensor_in mask_offset, | |||||
| _megdnn_tensor_in mask_val, | |||||
| _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout(const TensorLayout& src, | |||||
| const TensorLayout& mask_offset, | |||||
| const TensorLayout& mask_val, | |||||
| TensorLayout& dst); | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& mask_offset, | |||||
| const TensorLayout& mask_val, | |||||
| const TensorLayout& dst) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_in mask_offset, | |||||
| _megdnn_tensor_in mask_val, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout( | |||||
| const TensorLayout& src, const TensorLayout& mask_offset, | |||||
| const TensorLayout& mask_val, TensorLayout& dst); | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& mask_offset, | |||||
| const TensorLayout& mask_val, const TensorLayout& dst) = 0; | |||||
| protected: | protected: | ||||
| void check_layout_fwd(const TensorLayout& src, | |||||
| const TensorLayout& mask_offset, | |||||
| const TensorLayout& mask_val, | |||||
| const TensorLayout& dst); | |||||
| void deduce_layout_fwd(const TensorLayout& src, | |||||
| const TensorLayout& mask_offset, | |||||
| const TensorLayout& mask_val, | |||||
| TensorLayout& dst); | |||||
| void check_layout_fwd( | |||||
| const TensorLayout& src, const TensorLayout& mask_offset, | |||||
| const TensorLayout& mask_val, const TensorLayout& dst); | |||||
| void deduce_layout_fwd( | |||||
| const TensorLayout& src, const TensorLayout& mask_offset, | |||||
| const TensorLayout& mask_val, TensorLayout& dst); | |||||
| std::string param_msg() const; | std::string param_msg() const; | ||||
| }; | }; | ||||
| } // namespace megdnn | |||||
| } // namespace megdnn | |||||
| #include "megdnn/internal/opr_header_epilogue.h" | #include "megdnn/internal/opr_header_epilogue.h" | ||||
| @@ -33,22 +33,22 @@ public: | |||||
| * op(A) = A if transposeA is false, otherwise op(A) = A^t. | * op(A) = A if transposeA is false, otherwise op(A) = A^t. | ||||
| * op(B) = B if transposeB is false, otherwise op(B) = B^t. | * op(B) = B if transposeB is false, otherwise op(B) = B^t. | ||||
| */ | */ | ||||
| virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
| _megdnn_tensor_out C, _megdnn_workspace workspace) = 0; | |||||
| void deduce_dtype(DType A, DType B, DType &C); | |||||
| void deduce_layout(const TensorLayout& A, const TensorLayout& B, | |||||
| TensorLayout& C); | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_dtype(DType A, DType B, DType& C); | |||||
| void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | |||||
| static Algorithm::OprType get_opr_type() { | static Algorithm::OprType get_opr_type() { | ||||
| return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD; | return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD; | ||||
| } | } | ||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& A, const TensorLayout& B, | |||||
| const TensorLayout& C, size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| using BatchedMatrixMul = BatchedMatrixMulForward; | using BatchedMatrixMul = BatchedMatrixMulForward; | ||||
| @@ -70,24 +70,24 @@ public: | |||||
| * op(A) = A if transposeA is false, otherwise op(A) = A^t. | * op(A) = A if transposeA is false, otherwise op(A) = A^t. | ||||
| * op(B) = B if transposeB is false, otherwise op(B) = B^t. | * op(B) = B if transposeB is false, otherwise op(B) = B^t. | ||||
| */ | */ | ||||
| virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
| _megdnn_tensor_out C, _megdnn_workspace workspace) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_dtype(DType A, DType B, DType& C); | void deduce_dtype(DType A, DType B, DType& C); | ||||
| void deduce_layout(const TensorLayout& A, const TensorLayout& B, | |||||
| TensorLayout& C); | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) = 0; | |||||
| void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | |||||
| static size_t pack_size (const Param::Format format); | |||||
| static size_t pack_size(const Param::Format format); | |||||
| static Algorithm::OprType get_opr_type() { | static Algorithm::OprType get_opr_type() { | ||||
| return Algorithm::OprType::MATRIX_MUL_FORWARD; | return Algorithm::OprType::MATRIX_MUL_FORWARD; | ||||
| } | } | ||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& A, const TensorLayout& B, | |||||
| const TensorLayout& C, size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| using MatrixMul = MatrixMulForward; | using MatrixMul = MatrixMulForward; | ||||
| @@ -104,11 +104,11 @@ class MatrixInverse : public OperatorBase { | |||||
| DEF_OPR_PARAM(Empty); | DEF_OPR_PARAM(Empty); | ||||
| public: | public: | ||||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst); | void deduce_layout(const TensorLayout& src, TensorLayout& dst); | ||||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& dst); | |||||
| size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst); | |||||
| protected: | protected: | ||||
| /*! | /*! | ||||
| @@ -116,8 +116,7 @@ protected: | |||||
| * | * | ||||
| * Note that \p batch and \p n can be null | * Note that \p batch and \p n can be null | ||||
| */ | */ | ||||
| static void canonize_params(const TensorLayout& layout, size_t* batch, | |||||
| size_t* n); | |||||
| static void canonize_params(const TensorLayout& layout, size_t* batch, size_t* n); | |||||
| /*! | /*! | ||||
| * \brief canonize and validate input params for exec() impls | * \brief canonize and validate input params for exec() impls | ||||
| @@ -125,11 +124,12 @@ protected: | |||||
| * Since get_workspace_in_bytes() would be called, \p batch and \p n can not | * Since get_workspace_in_bytes() would be called, \p batch and \p n can not | ||||
| * be null | * be null | ||||
| */ | */ | ||||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
| _megdnn_workspace workspace, size_t* batch, size_t* n); | |||||
| void check_exec( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| _megdnn_workspace workspace, size_t* batch, size_t* n); | |||||
| virtual size_t get_workspace_in_bytes(size_t batch, size_t n, | |||||
| size_t dtype_size) = 0; | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| size_t batch, size_t n, size_t dtype_size) = 0; | |||||
| }; | }; | ||||
| //! inter-product of two vectors | //! inter-product of two vectors | ||||
| @@ -147,17 +147,17 @@ public: | |||||
| * A, B, C must be contiguous. A and B must have the same 1-dimensional | * A, B, C must be contiguous. A and B must have the same 1-dimensional | ||||
| * shape and non-negative strides. C must be scalar. | * shape and non-negative strides. C must be scalar. | ||||
| */ | */ | ||||
| virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
| _megdnn_tensor_out C, _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout(const TensorLayout& A, const TensorLayout& B, | |||||
| TensorLayout& C); | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& A, | |||||
| const TensorLayout& B, | |||||
| const TensorLayout& C) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | |||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& A, const TensorLayout& B, | |||||
| const TensorLayout& C, size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| using Dot = DotForward; | using Dot = DotForward; | ||||
| @@ -193,23 +193,24 @@ public: | |||||
| * if compute_uv is false (default to true). | * if compute_uv is false (default to true). | ||||
| * | * | ||||
| */ | */ | ||||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out u, | |||||
| _megdnn_tensor_out s, _megdnn_tensor_out vt, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout(const TensorLayout& src, TensorLayout& u, | |||||
| TensorLayout& s, TensorLayout& vt); | |||||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& u, const TensorLayout& s, | |||||
| const TensorLayout& vt); | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_out u, _megdnn_tensor_out s, | |||||
| _megdnn_tensor_out vt, _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout( | |||||
| const TensorLayout& src, TensorLayout& u, TensorLayout& s, | |||||
| TensorLayout& vt); | |||||
| size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& u, const TensorLayout& s, | |||||
| const TensorLayout& vt); | |||||
| protected: | protected: | ||||
| static void canonize_params(const TensorLayout& layout, size_t* batch, | |||||
| size_t* m, size_t* n); | |||||
| virtual size_t get_workspace_in_bytes(size_t block_cnt, size_t m, size_t n, | |||||
| size_t dtype_size) = 0; | |||||
| void check_exec(const TensorLayout& src, const TensorLayout& u, | |||||
| const TensorLayout& s, const TensorLayout& vt, | |||||
| size_t workspace_in_bytes); | |||||
| static void canonize_params( | |||||
| const TensorLayout& layout, size_t* batch, size_t* m, size_t* n); | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| size_t block_cnt, size_t m, size_t n, size_t dtype_size) = 0; | |||||
| void check_exec( | |||||
| const TensorLayout& src, const TensorLayout& u, const TensorLayout& s, | |||||
| const TensorLayout& vt, size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| using SVD = SVDForward; | using SVD = SVDForward; | ||||
| @@ -36,7 +36,7 @@ public: | |||||
| struct ModeTrait { | struct ModeTrait { | ||||
| uint32_t arity = 0; //!< number of inputs needed | uint32_t arity = 0; //!< number of inputs needed | ||||
| CheckDtypeFunc check_inp[MAX_ARITY]; | CheckDtypeFunc check_inp[MAX_ARITY]; | ||||
| SetOrCheckDtypeFunc check_out; //!< dtype of output var | |||||
| SetOrCheckDtypeFunc check_out; //!< dtype of output var | |||||
| bool need_specify_out_dtype = | bool need_specify_out_dtype = | ||||
| false; //!< the dtype should be setup externally, otherwise | false; //!< the dtype should be setup externally, otherwise | ||||
| //!< would be inferred by check_out(dtype, false) | //!< would be inferred by check_out(dtype, false) | ||||
| @@ -46,13 +46,10 @@ public: | |||||
| static const ModeTrait& from_mode(Mode mode); | static const ModeTrait& from_mode(Mode mode); | ||||
| }; | }; | ||||
| virtual void exec(_megdnn_in const TensorNDArray& src, | |||||
| _megdnn_tensor_out dst) = 0; | |||||
| virtual void exec(_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) = 0; | |||||
| //! get trait of current mode | //! get trait of current mode | ||||
| const ModeTrait& mode_trait() const { | |||||
| return ModeTrait::from_mode(m_param.mode); | |||||
| } | |||||
| const ModeTrait& mode_trait() const { return ModeTrait::from_mode(m_param.mode); } | |||||
| //! deduce output layout | //! deduce output layout | ||||
| void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst); | void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst); | ||||
| @@ -60,8 +57,8 @@ public: | |||||
| protected: | protected: | ||||
| //! throw exception if incorrect layout; broadcast input shape to | //! throw exception if incorrect layout; broadcast input shape to | ||||
| //! output shape | //! output shape | ||||
| void check_layout_and_broadcast(const TensorLayoutPtrArray& src, | |||||
| const TensorLayout& dst); | |||||
| void check_layout_and_broadcast( | |||||
| const TensorLayoutPtrArray& src, const TensorLayout& dst); | |||||
| }; | }; | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -15,84 +15,97 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| //! base class for random number generators | //! base class for random number generators | ||||
| class RNGBase: public OperatorBase { | |||||
| class RNGBase : public OperatorBase { | |||||
| DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase); | DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase); | ||||
| public: | |||||
| virtual void exec(_megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0; | |||||
| protected: | |||||
| virtual void check_exec(const TensorLayout &dst, size_t workspace_in_bytes) = 0; | |||||
| public: | |||||
| virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0; | |||||
| protected: | |||||
| virtual void check_exec(const TensorLayout& dst, size_t workspace_in_bytes) = 0; | |||||
| }; | }; | ||||
| //! sample from poisson distribution | //! sample from poisson distribution | ||||
| class PoissonRNG: public OperatorBase { | |||||
| class PoissonRNG : public OperatorBase { | |||||
| DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1); | DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1); | ||||
| DEF_OPR_PARAM(PoissonRNG); | DEF_OPR_PARAM(PoissonRNG); | ||||
| public: | |||||
| virtual void exec(_megdnn_tensor_in lam, | |||||
| _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout &lam, | |||||
| const TensorLayout &dst) = 0; | |||||
| protected: | |||||
| void check_exec(const TensorLayout &lam, const TensorLayout &dst, | |||||
| size_t workspace_in_bytes); | |||||
| public: | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in lam, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& lam, const TensorLayout& dst) = 0; | |||||
| protected: | |||||
| void check_exec( | |||||
| const TensorLayout& lam, const TensorLayout& dst, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| //! sample from beta distribution | //! sample from beta distribution | ||||
| class BetaRNG: public OperatorBase { | |||||
| class BetaRNG : public OperatorBase { | |||||
| DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1); | DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1); | ||||
| DEF_OPR_PARAM(BetaRNG); | DEF_OPR_PARAM(BetaRNG); | ||||
| public: | |||||
| virtual void exec(_megdnn_tensor_in alpha, | |||||
| _megdnn_tensor_in beta, | |||||
| _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout &alpha, | |||||
| const TensorLayout &beta, const TensorLayout &dst) = 0; | |||||
| protected: | |||||
| void check_exec(const TensorLayout &alpha, const TensorLayout &beta, | |||||
| const TensorLayout &dst, size_t workspace_in_bytes); | |||||
| public: | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in alpha, _megdnn_tensor_in beta, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& alpha, const TensorLayout& beta, | |||||
| const TensorLayout& dst) = 0; | |||||
| protected: | |||||
| void check_exec( | |||||
| const TensorLayout& alpha, const TensorLayout& beta, | |||||
| const TensorLayout& dst, size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| //! sample from gamma distribution | //! sample from gamma distribution | ||||
| class GammaRNG: public OperatorBase { | |||||
| class GammaRNG : public OperatorBase { | |||||
| DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1); | DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1); | ||||
| DEF_OPR_PARAM(GammaRNG); | DEF_OPR_PARAM(GammaRNG); | ||||
| public: | |||||
| virtual void exec(_megdnn_tensor_in shape, | |||||
| _megdnn_tensor_in scale, | |||||
| _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout &shape, | |||||
| const TensorLayout &scale, const TensorLayout &dst) = 0; | |||||
| protected: | |||||
| void check_exec(const TensorLayout &shape, const TensorLayout &scale, | |||||
| const TensorLayout &dst, size_t workspace_in_bytes); | |||||
| public: | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in shape, _megdnn_tensor_in scale, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& shape, const TensorLayout& scale, | |||||
| const TensorLayout& dst) = 0; | |||||
| protected: | |||||
| void check_exec( | |||||
| const TensorLayout& shape, const TensorLayout& scale, | |||||
| const TensorLayout& dst, size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| //! sample from uniform distribution on the interval (0, 1] | //! sample from uniform distribution on the interval (0, 1] | ||||
| class UniformRNG: public RNGBase { | |||||
| class UniformRNG : public RNGBase { | |||||
| DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1); | DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1); | ||||
| DEF_OPR_PARAM(UniformRNG); | DEF_OPR_PARAM(UniformRNG); | ||||
| protected: | |||||
| void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||||
| protected: | |||||
| void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| //! sample from gaussian distribution | //! sample from gaussian distribution | ||||
| class GaussianRNG: public RNGBase { | |||||
| class GaussianRNG : public RNGBase { | |||||
| DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1); | DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1); | ||||
| DEF_OPR_PARAM(GaussianRNG); | DEF_OPR_PARAM(GaussianRNG); | ||||
| protected: | |||||
| void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||||
| protected: | |||||
| void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| class PermutationRNG: public RNGBase { | |||||
| class PermutationRNG : public RNGBase { | |||||
| DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1); | DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1); | ||||
| DEF_OPR_PARAM(PermutationRNG); | DEF_OPR_PARAM(PermutationRNG); | ||||
| protected: | |||||
| void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||||
| protected: | |||||
| void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| class ShuffleRNGForward : public OperatorBase { | class ShuffleRNGForward : public OperatorBase { | ||||
| @@ -100,18 +113,19 @@ class ShuffleRNGForward : public OperatorBase { | |||||
| DEF_OPR_PARAM(ShuffleRNG); | DEF_OPR_PARAM(ShuffleRNG); | ||||
| public: | public: | ||||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_tensor_out indices, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst, | |||||
| TensorLayout& indices); | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
| const TensorLayout& dst, | |||||
| const TensorLayout& indices) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_tensor_out indices, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| void deduce_layout( | |||||
| const TensorLayout& src, TensorLayout& dst, TensorLayout& indices); | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| const TensorLayout& indices) = 0; | |||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
| const TensorLayout& indices, size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& src, const TensorLayout& dst, | |||||
| const TensorLayout& indices, size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| using ShuffleRNG = ShuffleRNGForward; | using ShuffleRNG = ShuffleRNGForward; | ||||
| @@ -120,27 +134,29 @@ class ShuffleRNGBackward : public OperatorBase { | |||||
| DEF_OPR_PARAM(ShuffleRNG); | DEF_OPR_PARAM(ShuffleRNG); | ||||
| public: | public: | ||||
| virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices, | |||||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||||
| const TensorLayout& indices, | |||||
| const TensorLayout& grad) = 0; | |||||
| virtual void exec( | |||||
| _megdnn_tensor_in diff, _megdnn_tensor_in indices, _megdnn_tensor_out grad, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& diff, const TensorLayout& indices, | |||||
| const TensorLayout& grad) = 0; | |||||
| protected: | protected: | ||||
| void check_exec(const TensorLayout& diff, const TensorLayout& indices, | |||||
| const TensorLayout& grad, size_t workspace_in_bytes); | |||||
| void check_exec( | |||||
| const TensorLayout& diff, const TensorLayout& indices, | |||||
| const TensorLayout& grad, size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| /*! | /*! | ||||
| * \brief sleep for specific time on the computing device; useful for testing | * \brief sleep for specific time on the computing device; useful for testing | ||||
| * async problems | * async problems | ||||
| */ | */ | ||||
| class SleepForward: public OperatorBase { | |||||
| class SleepForward : public OperatorBase { | |||||
| DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0); | DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0); | ||||
| DEF_OPR_PARAM(Sleep); | DEF_OPR_PARAM(Sleep); | ||||
| public: | |||||
| virtual void exec() = 0; | |||||
| public: | |||||
| virtual void exec() = 0; | |||||
| }; | }; | ||||
| using Sleep = SleepForward; | using Sleep = SleepForward; | ||||
| @@ -149,20 +165,19 @@ using Sleep = SleepForward; | |||||
| * | * | ||||
| * data must be a one-dimensional contiguous tensor with dtype byte | * data must be a one-dimensional contiguous tensor with dtype byte | ||||
| */ | */ | ||||
| class ChecksumForward: public OperatorBase { | |||||
| class ChecksumForward : public OperatorBase { | |||||
| DEF_OPR_PARAM(Empty); | DEF_OPR_PARAM(Empty); | ||||
| DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1); | DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1); | ||||
| public: | |||||
| using Result = opr_result::Checksum; | |||||
| public: | |||||
| using Result = opr_result::Checksum; | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout &data) = 0; | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0; | |||||
| virtual Result exec(_megdnn_tensor_in data, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) = 0; | |||||
| protected: | |||||
| void check_exec(const TensorLayout &layout, size_t workspace_in_bytes); | |||||
| protected: | |||||
| void check_exec(const TensorLayout& layout, size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| using Checksum = ChecksumForward; | using Checksum = ChecksumForward; | ||||
| @@ -175,21 +190,22 @@ class MaxTensorDiff : public OperatorBase { | |||||
| DEF_OPR_PARAM(Empty); | DEF_OPR_PARAM(Empty); | ||||
| DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2); | DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2); | ||||
| public: | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout& layout1, | |||||
| const TensorLayout& layout2) = 0; | |||||
| public: | |||||
| virtual size_t get_workspace_in_bytes( | |||||
| const TensorLayout& layout1, const TensorLayout& layout2) = 0; | |||||
| virtual float exec(_megdnn_tensor_in src1, _megdnn_tensor_in src2, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| virtual float exec( | |||||
| _megdnn_tensor_in src1, _megdnn_tensor_in src2, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| protected: | |||||
| void check_exec(const TensorLayout& layout1, | |||||
| const TensorLayout& layout2, size_t workspace_in_bytes); | |||||
| protected: | |||||
| void check_exec( | |||||
| const TensorLayout& layout1, const TensorLayout& layout2, | |||||
| size_t workspace_in_bytes); | |||||
| }; | }; | ||||
| bool check_bias_share_in_channel(const TensorLayout& bias, | |||||
| const param::ConvBias::Format format); | |||||
| bool check_bias_share_in_channel( | |||||
| const TensorLayout& bias, const param::ConvBias::Format format); | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -18,9 +18,9 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| enum class TensorFormat::Type { | enum class TensorFormat::Type { | ||||
| DEFAULT = 0, //!< see DefaultTensorFormat | |||||
| IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat | |||||
| LOWBITS_ALIGNED_TO_BYTE = 2, //!< | |||||
| DEFAULT = 0, //!< see DefaultTensorFormat | |||||
| IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat | |||||
| LOWBITS_ALIGNED_TO_BYTE = 2, //!< | |||||
| }; | }; | ||||
| class TensorFormat::ImplBase { | class TensorFormat::ImplBase { | ||||
| @@ -33,8 +33,7 @@ public: | |||||
| virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0; | virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0; | ||||
| virtual TensorLayout collapse_contiguous_spec( | |||||
| const TensorLayout& layout) const = 0; | |||||
| virtual TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const = 0; | |||||
| virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0; | virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0; | ||||
| @@ -79,8 +78,7 @@ public: | |||||
| */ | */ | ||||
| bool is_contiguous_spec(const TensorLayout& layout) const override; | bool is_contiguous_spec(const TensorLayout& layout) const override; | ||||
| TensorLayout collapse_contiguous_spec( | |||||
| const TensorLayout& layout) const override; | |||||
| TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override; | |||||
| TensorLayout::Span span_spec(const TensorLayout& layout) const override; | TensorLayout::Span span_spec(const TensorLayout& layout) const override; | ||||
| @@ -88,8 +86,7 @@ public: | |||||
| void serialize_append(std::string& result) const override; | void serialize_append(std::string& result) const override; | ||||
| static TensorFormat make(); | static TensorFormat make(); | ||||
| static TensorFormat deserialize(const Handle* handle, const void* buf, | |||||
| size_t size); | |||||
| static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size); | |||||
| }; | }; | ||||
| namespace detail { | namespace detail { | ||||
| @@ -112,8 +109,8 @@ class Image2DTensorFormatBase : public TensorFormat::ImplBase { | |||||
| size_t m_align_axis, m_align_size_in_elements_log2; | size_t m_align_axis, m_align_size_in_elements_log2; | ||||
| protected: | protected: | ||||
| Image2DTensorFormatBase(Type type, size_t align_axis, | |||||
| size_t align_size_in_elements); | |||||
| Image2DTensorFormatBase( | |||||
| Type type, size_t align_axis, size_t align_size_in_elements); | |||||
| virtual ~Image2DTensorFormatBase() = default; | virtual ~Image2DTensorFormatBase() = default; | ||||
| public: | public: | ||||
| @@ -129,9 +126,7 @@ public: | |||||
| size_t align_axis() const { return m_align_axis; } | size_t align_axis() const { return m_align_axis; } | ||||
| size_t align_size_in_elements_log2() const { | |||||
| return m_align_size_in_elements_log2; | |||||
| } | |||||
| size_t align_size_in_elements_log2() const { return m_align_size_in_elements_log2; } | |||||
| std::string to_string() const override; | std::string to_string() const override; | ||||
| @@ -145,6 +140,7 @@ public: | |||||
| size_t image_height(const TensorLayout& layout) const; | size_t image_height(const TensorLayout& layout) const; | ||||
| void serialize_append(std::string& result) const override; | void serialize_append(std::string& result) const override; | ||||
| protected: | protected: | ||||
| struct SerializePack { | struct SerializePack { | ||||
| uint8_t align_axis; | uint8_t align_axis; | ||||
| @@ -160,15 +156,14 @@ class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase { | |||||
| * align COUNT, but mdl needs align size in byte, which equal to | * align COUNT, but mdl needs align size in byte, which equal to | ||||
| * (image_width algin count) * sizeof(data_type) * pixel_size | * (image_width algin count) * sizeof(data_type) * pixel_size | ||||
| */ | */ | ||||
| size_t image_pitch_alignment_in_bytes(size_t align_size_in_elements, | |||||
| const TensorLayout& layout) const; | |||||
| size_t image_pitch_alignment_in_bytes( | |||||
| size_t align_size_in_elements, const TensorLayout& layout) const; | |||||
| protected: | protected: | ||||
| Image2DPackedTensorFormatBase(Type type, size_t align_axis, | |||||
| size_t align_size_in_elements, | |||||
| Handle::HandleVendorType vendor_type) | |||||
| : detail::Image2DTensorFormatBase(type, align_axis, | |||||
| align_size_in_elements), | |||||
| Image2DPackedTensorFormatBase( | |||||
| Type type, size_t align_axis, size_t align_size_in_elements, | |||||
| Handle::HandleVendorType vendor_type) | |||||
| : detail::Image2DTensorFormatBase(type, align_axis, align_size_in_elements), | |||||
| m_vendor_type(vendor_type) {} | m_vendor_type(vendor_type) {} | ||||
| virtual ~Image2DPackedTensorFormatBase() = default; | virtual ~Image2DPackedTensorFormatBase() = default; | ||||
| @@ -197,13 +192,12 @@ public: | |||||
| bool is_contiguous_spec(const TensorLayout& layout) const override; | bool is_contiguous_spec(const TensorLayout& layout) const override; | ||||
| TensorLayout collapse_contiguous_spec( | |||||
| const TensorLayout& layout) const override; | |||||
| TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override; | |||||
| }; | }; | ||||
| using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; | using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; | ||||
| /*! | /*! | ||||
| * \brief used for tensors storing lowbit data | |||||
| * \brief used for tensors storing lowbit data | |||||
| * | * | ||||
| * \param m_size_nbits size in bits of elements in the tensor | * \param m_size_nbits size in bits of elements in the tensor | ||||
| * \param m_align_size_in_bits aligned size in bits | * \param m_align_size_in_bits aligned size in bits | ||||
| @@ -213,14 +207,14 @@ class LowbitsAlignedTensorFormatBase : public TensorFormat::ImplBase { | |||||
| size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements; | size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements; | ||||
| protected: //? | protected: //? | ||||
| LowbitsAlignedTensorFormatBase(Type type, size_t size_nbits, | |||||
| size_t align_size_in_bits); | |||||
| LowbitsAlignedTensorFormatBase( | |||||
| Type type, size_t size_nbits, size_t align_size_in_bits); | |||||
| virtual ~LowbitsAlignedTensorFormatBase() = default; | virtual ~LowbitsAlignedTensorFormatBase() = default; | ||||
| public: | public: | ||||
| size_t align_size_in_bits() const { return m_align_size_in_bits; } | size_t align_size_in_bits() const { return m_align_size_in_bits; } | ||||
| size_t size_nbits() const { return m_size_nbits; } | size_t size_nbits() const { return m_size_nbits; } | ||||
| std::string to_string() const override; | std::string to_string() const override; | ||||
| @@ -238,8 +232,8 @@ public: | |||||
| bool is_contiguous_spec(const TensorLayout& layout) const override; | bool is_contiguous_spec(const TensorLayout& layout) const override; | ||||
| TensorLayout collapse_contiguous_spec( | |||||
| const TensorLayout& layout) const override; | |||||
| TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override; | |||||
| protected: | protected: | ||||
| struct SerializePack { | struct SerializePack { | ||||
| uint8_t size_nbits; | uint8_t size_nbits; | ||||
| @@ -254,16 +248,14 @@ protected: | |||||
| * | * | ||||
| * This is used for OpenCL. | * This is used for OpenCL. | ||||
| */ | */ | ||||
| class Image2DPack4TensorFormat final | |||||
| : public detail::Image2DPack4TensorFormatBase { | |||||
| class Image2DPack4TensorFormat final : public detail::Image2DPack4TensorFormatBase { | |||||
| public: | public: | ||||
| static constexpr Type TYPE = Type::IMAGE2D_PACK4; | static constexpr Type TYPE = Type::IMAGE2D_PACK4; | ||||
| //! for internal usage or test purposes | //! for internal usage or test purposes | ||||
| static TensorFormat make_raw(size_t align_axis, | |||||
| size_t align_size_in_elements, | |||||
| Handle::HandleVendorType vendor_type = | |||||
| Handle::HandleVendorType::NOT_SPEC); | |||||
| static TensorFormat make_raw( | |||||
| size_t align_axis, size_t align_size_in_elements, | |||||
| Handle::HandleVendorType vendor_type = Handle::HandleVendorType::NOT_SPEC); | |||||
| static TensorFormat make(size_t align_axis, const Handle* handle); | static TensorFormat make(size_t align_axis, const Handle* handle); | ||||
| @@ -273,13 +265,11 @@ public: | |||||
| * Note that the alignment may be different if deserialized on another | * Note that the alignment may be different if deserialized on another | ||||
| * handle | * handle | ||||
| */ | */ | ||||
| static TensorFormat deserialize(const Handle* handle, const void* buf, | |||||
| size_t size); | |||||
| static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size); | |||||
| static bool is_valid_image(const TensorLayout& layout) { | static bool is_valid_image(const TensorLayout& layout) { | ||||
| if (layout.format.type() == TYPE) { | if (layout.format.type() == TYPE) { | ||||
| layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid( | |||||
| layout); | |||||
| layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid(layout); | |||||
| return true; | return true; | ||||
| } | } | ||||
| return false; | return false; | ||||
| @@ -288,8 +278,9 @@ public: | |||||
| TensorFormat change_axis(size_t axis) const override; | TensorFormat change_axis(size_t axis) const override; | ||||
| private: | private: | ||||
| Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_elements, | |||||
| Handle::HandleVendorType vendor_type) | |||||
| Image2DPack4TensorFormat( | |||||
| size_t align_axis, size_t align_size_in_elements, | |||||
| Handle::HandleVendorType vendor_type) | |||||
| : detail::Image2DPack4TensorFormatBase( | : detail::Image2DPack4TensorFormatBase( | ||||
| TYPE, align_axis, align_size_in_elements, vendor_type) {} | TYPE, align_axis, align_size_in_elements, vendor_type) {} | ||||
| }; | }; | ||||
| @@ -306,13 +297,12 @@ public: | |||||
| static TensorFormat make(size_t size_nbits); | static TensorFormat make(size_t size_nbits); | ||||
| static TensorFormat deserialize(const Handle* handle, const void* buf, | |||||
| size_t size); | |||||
| static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size); | |||||
| static bool is_valid_layout(const TensorLayout& layout) { | static bool is_valid_layout(const TensorLayout& layout) { | ||||
| if (layout.format.type() == TYPE) { | if (layout.format.type() == TYPE) { | ||||
| layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>() | |||||
| .assert_valid(layout); | |||||
| layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>().assert_valid( | |||||
| layout); | |||||
| return true; | return true; | ||||
| } | } | ||||
| return false; | return false; | ||||
| @@ -320,8 +310,7 @@ public: | |||||
| private: | private: | ||||
| LowbitsAlignedToBytesTensorFormat(size_t size_nbits) | LowbitsAlignedToBytesTensorFormat(size_t size_nbits) | ||||
| : detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, | |||||
| BYTE_IN_BITS) {} | |||||
| : detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, BYTE_IN_BITS) {} | |||||
| }; | }; | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -167,13 +167,11 @@ public: | |||||
| TensorIter(const TensorND& tensor) : m_tensor(tensor) {} | TensorIter(const TensorND& tensor) : m_tensor(tensor) {} | ||||
| Iter begin() const { | |||||
| return Iter::make(const_cast<TensorND&>(m_tensor), 0); | |||||
| } | |||||
| Iter begin() const { return Iter::make(const_cast<TensorND&>(m_tensor), 0); } | |||||
| Iter end() const { | Iter end() const { | ||||
| return Iter::make(const_cast<TensorND&>(m_tensor), | |||||
| m_tensor.layout.total_nr_elems()); | |||||
| return Iter::make( | |||||
| const_cast<TensorND&>(m_tensor), m_tensor.layout.total_nr_elems()); | |||||
| } | } | ||||
| }; | }; | ||||
| /*! | /*! | ||||
| @@ -11,19 +11,19 @@ | |||||
| #pragma once | #pragma once | ||||
| #include <type_traits> | |||||
| #include <cstdlib> | |||||
| #include <functional> | #include <functional> | ||||
| #include <utility> | |||||
| #include <memory> | #include <memory> | ||||
| #include <cstdlib> | |||||
| #include <type_traits> | |||||
| #include <utility> | |||||
| #include "megdnn/internal/visibility_prologue.h" | #include "megdnn/internal/visibility_prologue.h" | ||||
| namespace megdnn { | namespace megdnn { | ||||
| template<typename Signature> | |||||
| template <typename Signature> | |||||
| using thin_function = ::std::function<Signature>; | using thin_function = ::std::function<Signature>; | ||||
| } // namespace megdnn | |||||
| } // namespace megdnn | |||||
| #include "megdnn/internal/visibility_epilogue.h" | #include "megdnn/internal/visibility_epilogue.h" | ||||
| @@ -58,18 +58,16 @@ protected: | |||||
| m_end_ptr(first_elm), | m_end_ptr(first_elm), | ||||
| m_capacity_ptr(static_cast<char*>(first_elm) + size) {} | m_capacity_ptr(static_cast<char*>(first_elm) + size) {} | ||||
| void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, | |||||
| size_t type_size); | |||||
| void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, size_t type_size); | |||||
| public: | public: | ||||
| size_t size_in_bytes() const { | size_t size_in_bytes() const { | ||||
| return size_t(static_cast<char*>(m_end_ptr) - | |||||
| static_cast<char*>(m_begin_ptr)); | |||||
| return size_t(static_cast<char*>(m_end_ptr) - static_cast<char*>(m_begin_ptr)); | |||||
| } | } | ||||
| size_t capacity_in_bytes() const { | size_t capacity_in_bytes() const { | ||||
| return size_t(static_cast<char*>(m_capacity_ptr) - | |||||
| static_cast<char*>(m_begin_ptr)); | |||||
| return size_t( | |||||
| static_cast<char*>(m_capacity_ptr) - static_cast<char*>(m_begin_ptr)); | |||||
| } | } | ||||
| bool empty() const { return m_begin_ptr == m_end_ptr; } | bool empty() const { return m_begin_ptr == m_end_ptr; } | ||||
| @@ -85,20 +83,15 @@ private: | |||||
| U m_first_elm; | U m_first_elm; | ||||
| protected: | protected: | ||||
| SmallVectorTemplateCommon(size_t size) | |||||
| : SmallVectorBase(&m_first_elm, size) {} | |||||
| SmallVectorTemplateCommon(size_t size) : SmallVectorBase(&m_first_elm, size) {} | |||||
| void grow_pod(size_t min_sz_in_bytes, size_t type_size) { | void grow_pod(size_t min_sz_in_bytes, size_t type_size) { | ||||
| SmallVectorBase::grow_pod(&m_first_elm, min_sz_in_bytes, type_size); | SmallVectorBase::grow_pod(&m_first_elm, min_sz_in_bytes, type_size); | ||||
| } | } | ||||
| bool is_small() { | |||||
| return m_begin_ptr == static_cast<const void*>(&m_first_elm); | |||||
| } | |||||
| bool is_small() { return m_begin_ptr == static_cast<const void*>(&m_first_elm); } | |||||
| void reset_to_small() { | |||||
| m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm; | |||||
| } | |||||
| void reset_to_small() { m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm; } | |||||
| void set_end(T* p) { m_end_ptr = p; } | void set_end(T* p) { m_end_ptr = p; } | ||||
| @@ -128,20 +121,12 @@ protected: | |||||
| public: | public: | ||||
| // forwarding iterator creation | // forwarding iterator creation | ||||
| iterator begin() { return static_cast<iterator>(m_begin_ptr); } | iterator begin() { return static_cast<iterator>(m_begin_ptr); } | ||||
| const_iterator begin() const { | |||||
| return static_cast<const_iterator>(m_begin_ptr); | |||||
| } | |||||
| const_iterator cbegin() const { | |||||
| return static_cast<const_iterator>(m_begin_ptr); | |||||
| } | |||||
| const_iterator begin() const { return static_cast<const_iterator>(m_begin_ptr); } | |||||
| const_iterator cbegin() const { return static_cast<const_iterator>(m_begin_ptr); } | |||||
| iterator end() { return static_cast<iterator>(m_end_ptr); } | iterator end() { return static_cast<iterator>(m_end_ptr); } | ||||
| const_iterator end() const { | |||||
| return static_cast<const_iterator>(m_end_ptr); | |||||
| } | |||||
| const_iterator cend() const { | |||||
| return static_cast<const_iterator>(m_end_ptr); | |||||
| } | |||||
| const_iterator end() const { return static_cast<const_iterator>(m_end_ptr); } | |||||
| const_iterator cend() const { return static_cast<const_iterator>(m_end_ptr); } | |||||
| reference at(size_type idx) { | reference at(size_type idx) { | ||||
| if (idx >= size()) { | if (idx >= size()) { | ||||
| @@ -167,13 +152,9 @@ public: | |||||
| // reverse iterator creation method. | // reverse iterator creation method. | ||||
| reverse_iterator rbegin() { return reverse_iterator(end()); } | reverse_iterator rbegin() { return reverse_iterator(end()); } | ||||
| const_reverse_iterator rbegin() const { | |||||
| return const_reverse_iterator(end()); | |||||
| } | |||||
| const_reverse_iterator rbegin() const { return const_reverse_iterator(end()); } | |||||
| reverse_iterator rend() { return reverse_iterator(begin()); } | reverse_iterator rend() { return reverse_iterator(begin()); } | ||||
| const_reverse_iterator rend() const { | |||||
| return const_reverse_iterator(begin()); | |||||
| } | |||||
| const_reverse_iterator rend() const { return const_reverse_iterator(begin()); } | |||||
| pointer data() { return pointer(begin()); } | pointer data() { return pointer(begin()); } | ||||
| const_pointer data() const { return const_pointer(begin()); } | const_pointer data() const { return const_pointer(begin()); } | ||||
| @@ -207,8 +188,8 @@ protected: | |||||
| template <typename It1, typename It2> | template <typename It1, typename It2> | ||||
| static void uninitialized_move(It1 first, It1 last, It2 dest) { | static void uninitialized_move(It1 first, It1 last, It2 dest) { | ||||
| std::uninitialized_copy(std::make_move_iterator(first), | |||||
| std::make_move_iterator(last), dest); | |||||
| std::uninitialized_copy( | |||||
| std::make_move_iterator(first), std::make_move_iterator(last), dest); | |||||
| } | } | ||||
| template <typename It1, typename It2> | template <typename It1, typename It2> | ||||
| @@ -293,9 +274,7 @@ protected: | |||||
| memcpy(dest, first, (last - first) * sizeof(T)); | memcpy(dest, first, (last - first) * sizeof(T)); | ||||
| } | } | ||||
| void grow(size_t min_sz = 0) { | |||||
| this->grow_pod(min_sz * sizeof(T), sizeof(T)); | |||||
| } | |||||
| void grow(size_t min_sz = 0) { this->grow_pod(min_sz * sizeof(T), sizeof(T)); } | |||||
| public: | public: | ||||
| void push_back(const T& _elm) { | void push_back(const T& _elm) { | ||||
| @@ -318,8 +297,7 @@ public: | |||||
| * SmallVector<T, N> can be converted to SmallVectorImpl<T> to erase N | * SmallVector<T, N> can be converted to SmallVectorImpl<T> to erase N | ||||
| */ | */ | ||||
| template <typename T> | template <typename T> | ||||
| class SmallVectorImpl | |||||
| : public SmallVectorTemplateBase<T, std::is_pod<T>::value> { | |||||
| class SmallVectorImpl : public SmallVectorTemplateBase<T, std::is_pod<T>::value> { | |||||
| using SuperClass = SmallVectorTemplateBase<T, std::is_pod<T>::value>; | using SuperClass = SmallVectorTemplateBase<T, std::is_pod<T>::value>; | ||||
| public: | public: | ||||
| @@ -329,8 +307,7 @@ public: | |||||
| protected: | protected: | ||||
| explicit SmallVectorImpl(unsigned n) | explicit SmallVectorImpl(unsigned n) | ||||
| : SmallVectorTemplateBase<T, std::is_pod<T>::value>(n * sizeof(T)) { | |||||
| } | |||||
| : SmallVectorTemplateBase<T, std::is_pod<T>::value>(n * sizeof(T)) {} | |||||
| public: | public: | ||||
| SmallVectorImpl(const SmallVectorImpl&) = delete; | SmallVectorImpl(const SmallVectorImpl&) = delete; | ||||
| @@ -354,8 +331,7 @@ public: | |||||
| } else if (n > this->size()) { | } else if (n > this->size()) { | ||||
| if (this->capacity() < n) | if (this->capacity() < n) | ||||
| this->grow(n); | this->grow(n); | ||||
| for (auto it = this->end(), end = this->begin() + n; it != end; | |||||
| ++it) | |||||
| for (auto it = this->end(), end = this->begin() + n; it != end; ++it) | |||||
| new (&*it) T(); | new (&*it) T(); | ||||
| this->set_end(this->begin() + n); | this->set_end(this->begin() + n); | ||||
| } | } | ||||
| @@ -389,10 +365,11 @@ public: | |||||
| void swap(SmallVectorImpl<T>& rhs); | void swap(SmallVectorImpl<T>& rhs); | ||||
| /// Add the specified range to the end of the SmallVector. | /// Add the specified range to the end of the SmallVector. | ||||
| template <typename in_iter, | |||||
| typename = typename std::enable_if<std::is_convertible< | |||||
| typename std::iterator_traits<in_iter>::iterator_category, | |||||
| std::input_iterator_tag>::value>::type> | |||||
| template < | |||||
| typename in_iter, | |||||
| typename = typename std::enable_if<std::is_convertible< | |||||
| typename std::iterator_traits<in_iter>::iterator_category, | |||||
| std::input_iterator_tag>::value>::type> | |||||
| void append(in_iter in_start, in_iter in_end) { | void append(in_iter in_start, in_iter in_end) { | ||||
| size_type num_inputs = std::distance(in_start, in_end); | size_type num_inputs = std::distance(in_start, in_end); | ||||
| // Grow allocated space if needed. | // Grow allocated space if needed. | ||||
| @@ -432,10 +409,11 @@ public: | |||||
| std::uninitialized_fill(this->begin(), this->end(), elm); | std::uninitialized_fill(this->begin(), this->end(), elm); | ||||
| } | } | ||||
| template <typename in_iter, | |||||
| typename = typename std::enable_if<std::is_convertible< | |||||
| typename std::iterator_traits<in_iter>::iterator_category, | |||||
| std::input_iterator_tag>::value>::type> | |||||
| template < | |||||
| typename in_iter, | |||||
| typename = typename std::enable_if<std::is_convertible< | |||||
| typename std::iterator_traits<in_iter>::iterator_category, | |||||
| std::input_iterator_tag>::value>::type> | |||||
| void assign(in_iter in_start, in_iter in_end) { | void assign(in_iter in_start, in_iter in_end) { | ||||
| clear(); | clear(); | ||||
| append(in_start, in_end); | append(in_start, in_end); | ||||
| @@ -571,8 +549,7 @@ public: | |||||
| std::fill_n(it, num_overwritten, elm); | std::fill_n(it, num_overwritten, elm); | ||||
| // Insert the non-overwritten middle part. | // Insert the non-overwritten middle part. | ||||
| std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten, | |||||
| elm); | |||||
| std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten, elm); | |||||
| return it; | return it; | ||||
| } | } | ||||
| @@ -646,8 +623,7 @@ public: | |||||
| if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) { | if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) { | ||||
| this->grow(); | this->grow(); | ||||
| } | } | ||||
| new (static_cast<void*>(this->end())) | |||||
| T(std::forward<ArgTypes>(args)...); | |||||
| new (static_cast<void*>(this->end())) T(std::forward<ArgTypes>(args)...); | |||||
| this->set_end(this->end() + 1); | this->set_end(this->end() + 1); | ||||
| } | } | ||||
| @@ -661,13 +637,11 @@ public: | |||||
| return std::equal(this->begin(), this->end(), rhs.begin()); | return std::equal(this->begin(), this->end(), rhs.begin()); | ||||
| } | } | ||||
| bool operator!=(const SmallVectorImpl<T>& rhs) const { | |||||
| return !(*this == rhs); | |||||
| } | |||||
| bool operator!=(const SmallVectorImpl<T>& rhs) const { return !(*this == rhs); } | |||||
| bool operator<(const SmallVectorImpl<T>& rhs) const { | bool operator<(const SmallVectorImpl<T>& rhs) const { | ||||
| return std::lexicographical_compare(this->begin(), this->end(), | |||||
| rhs.begin(), rhs.end()); | |||||
| return std::lexicographical_compare( | |||||
| this->begin(), this->end(), rhs.begin(), rhs.end()); | |||||
| } | } | ||||
| }; | }; | ||||
| @@ -698,15 +672,13 @@ void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) { | |||||
| // Copy over the extra elms. | // Copy over the extra elms. | ||||
| if (this->size() > rhs.size()) { | if (this->size() > rhs.size()) { | ||||
| size_t elm_diff = this->size() - rhs.size(); | size_t elm_diff = this->size() - rhs.size(); | ||||
| this->uninitialized_move(this->begin() + num_shared, this->end(), | |||||
| rhs.end()); | |||||
| this->uninitialized_move(this->begin() + num_shared, this->end(), rhs.end()); | |||||
| rhs.set_end(rhs.end() + elm_diff); | rhs.set_end(rhs.end() + elm_diff); | ||||
| this->destroy_range(this->begin() + num_shared, this->end()); | this->destroy_range(this->begin() + num_shared, this->end()); | ||||
| this->set_end(this->begin() + num_shared); | this->set_end(this->begin() + num_shared); | ||||
| } else if (rhs.size() > this->size()) { | } else if (rhs.size() > this->size()) { | ||||
| size_t elm_diff = rhs.size() - this->size(); | size_t elm_diff = rhs.size() - this->size(); | ||||
| this->uninitialized_move(rhs.begin() + num_shared, rhs.end(), | |||||
| this->end()); | |||||
| this->uninitialized_move(rhs.begin() + num_shared, rhs.end(), this->end()); | |||||
| this->set_end(this->end() + elm_diff); | this->set_end(this->end() + elm_diff); | ||||
| this->destroy_range(rhs.begin() + num_shared, rhs.end()); | this->destroy_range(rhs.begin() + num_shared, rhs.end()); | ||||
| rhs.set_end(rhs.begin() + num_shared); | rhs.set_end(rhs.begin() + num_shared); | ||||
| @@ -714,8 +686,7 @@ void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) { | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| SmallVectorImpl<T>& SmallVectorImpl<T>::operator=( | |||||
| const SmallVectorImpl<T>& rhs) { | |||||
| SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(const SmallVectorImpl<T>& rhs) { | |||||
| if (this == &rhs) | if (this == &rhs) | ||||
| return *this; | return *this; | ||||
| size_t rhs_sz = rhs.size(); | size_t rhs_sz = rhs.size(); | ||||
| @@ -740,8 +711,7 @@ SmallVectorImpl<T>& SmallVectorImpl<T>::operator=( | |||||
| } else if (cur_sz) { | } else if (cur_sz) { | ||||
| std::copy(rhs.begin(), rhs.begin() + cur_sz, this->begin()); | std::copy(rhs.begin(), rhs.begin() + cur_sz, this->begin()); | ||||
| } | } | ||||
| std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(), | |||||
| this->begin() + cur_sz); | |||||
| std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz); | |||||
| this->set_end(this->begin() + rhs_sz); | this->set_end(this->begin() + rhs_sz); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -785,8 +755,7 @@ SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(SmallVectorImpl<T>&& rhs) { | |||||
| std::move(rhs.begin(), rhs.begin() + cur_sz, this->begin()); | std::move(rhs.begin(), rhs.begin() + cur_sz, this->begin()); | ||||
| } | } | ||||
| this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(), | |||||
| this->begin() + cur_sz); | |||||
| this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz); | |||||
| this->set_end(this->begin() + rhs_sz); | this->set_end(this->begin() + rhs_sz); | ||||
| @@ -826,8 +795,7 @@ class SmallVector : public SmallVectorImpl<T> { | |||||
| public: | public: | ||||
| SmallVector() : SmallVectorImpl<T>(N) {} | SmallVector() : SmallVectorImpl<T>(N) {} | ||||
| explicit SmallVector(size_t size, const T& value = T()) | |||||
| : SmallVectorImpl<T>(N) { | |||||
| explicit SmallVector(size_t size, const T& value = T()) : SmallVectorImpl<T>(N) { | |||||
| this->assign(size, value); | this->assign(size, value); | ||||
| } | } | ||||
| @@ -901,15 +869,13 @@ namespace std { | |||||
| /// Implement std::swap in terms of SmallVector swap. | /// Implement std::swap in terms of SmallVector swap. | ||||
| template <typename T> | template <typename T> | ||||
| inline void swap(megdnn::SmallVectorImpl<T>& lhs, | |||||
| megdnn::SmallVectorImpl<T>& rhs) { | |||||
| inline void swap(megdnn::SmallVectorImpl<T>& lhs, megdnn::SmallVectorImpl<T>& rhs) { | |||||
| lhs.swap(rhs); | lhs.swap(rhs); | ||||
| } | } | ||||
| /// Implement std::swap in terms of SmallVector swap. | /// Implement std::swap in terms of SmallVector swap. | ||||
| template <typename T, unsigned N> | template <typename T, unsigned N> | ||||
| inline void swap(megdnn::SmallVector<T, N>& lhs, | |||||
| megdnn::SmallVector<T, N>& rhs) { | |||||
| inline void swap(megdnn::SmallVector<T, N>& lhs, megdnn::SmallVector<T, N>& rhs) { | |||||
| lhs.swap(rhs); | lhs.swap(rhs); | ||||
| } | } | ||||
| } // end namespace std | } // end namespace std | ||||
| @@ -17,13 +17,13 @@ | |||||
| #include "megdnn/internal/visibility_prologue.h" | #include "megdnn/internal/visibility_prologue.h" | ||||
| namespace megdnn { | namespace megdnn { | ||||
| struct Version { | |||||
| int major, minor, patch; | |||||
| }; | |||||
| struct Version { | |||||
| int major, minor, patch; | |||||
| }; | |||||
| //! get megdnn version of the binary | |||||
| Version get_version(); | |||||
| } | |||||
| //! get megdnn version of the binary | |||||
| Version get_version(); | |||||
| } // namespace megdnn | |||||
| #include "megdnn/internal/visibility_epilogue.h" | #include "megdnn/internal/visibility_epilogue.h" | ||||
| @@ -22,18 +22,17 @@ using namespace aarch64; | |||||
| /* ===================== stride-2 algo ===================== */ | /* ===================== stride-2 algo ===================== */ | ||||
| MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16) | MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16) | ||||
| bool ConvBiasImpl::AlgoF16DirectStride2::usable(const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy) const { | |||||
| bool ConvBiasImpl::AlgoF16DirectStride2::usable( | |||||
| const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||||
| MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) { | MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) { | ||||
| auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
| auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
| return param.filter_meta.format == param::Convolution::Format::NCHW && | return param.filter_meta.format == param::Convolution::Format::NCHW && | ||||
| param.src_type.enumv() == DTypeEnum::Float16 && | param.src_type.enumv() == DTypeEnum::Float16 && | ||||
| param.filter_type.enumv() == DTypeEnum::Float16 && | param.filter_type.enumv() == DTypeEnum::Float16 && | ||||
| param.dst_type.enumv() == DTypeEnum::Float16 && | |||||
| !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
| fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||||
| FH == fm.spatial[1] && | |||||
| param.dst_type.enumv() == DTypeEnum::Float16 && !fm.should_flip && | |||||
| fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
| fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||||
| (FH == 2 || FH == 3 || FH == 5 || FH == 7); | (FH == 2 || FH == 3 || FH == 5 || FH == 7); | ||||
| } | } | ||||
| MIDOUT_END(); | MIDOUT_END(); | ||||
| @@ -52,8 +51,7 @@ size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace( | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| SmallVector<ConvBiasImpl::NCBKern> | |||||
| ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( | |||||
| SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( | |||||
| const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
| MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) { | MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) { | ||||
| return get_kimpls(param); | return get_kimpls(param); | ||||
| @@ -62,8 +60,7 @@ ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( | |||||
| return {}; | return {}; | ||||
| } | } | ||||
| SmallVector<ConvBiasImpl::NCBKern> | |||||
| ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||||
| SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||||
| const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
| auto fm = param.filter_meta; | auto fm = param.filter_meta; | ||||
| auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
| @@ -72,8 +69,9 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||||
| size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
| size_t group = fm.group; | size_t group = fm.group; | ||||
| bool large_group = group >= param.nr_threads; | bool large_group = group >= param.nr_threads; | ||||
| using Func = std::function<void(const __fp16*, const __fp16*, __fp16*, | |||||
| size_t, size_t, size_t, size_t, size_t)>; | |||||
| using Func = std::function<void( | |||||
| const __fp16*, const __fp16*, __fp16*, size_t, size_t, size_t, size_t, | |||||
| size_t)>; | |||||
| Func conv = nullptr; | Func conv = nullptr; | ||||
| if (FH == 2) { | if (FH == 2) { | ||||
| conv = fp16::conv_stride2::do_conv_2x2_stride2; | conv = fp16::conv_stride2::do_conv_2x2_stride2; | ||||
| @@ -101,31 +99,35 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||||
| bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
| for (size_t ic = 0; ic < IC; ic++) { | for (size_t ic = 0; ic < IC; ic++) { | ||||
| arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | ||||
| copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||||
| {ncb_index.thread_id, 0, ic}); | |||||
| copy_padding_kern_stride( | |||||
| bundle, kern_param, ncb_index, | |||||
| {ncb_index.thread_id, 0, ic}); | |||||
| } | } | ||||
| for (size_t oc = 0; oc < OC; oc++) { | for (size_t oc = 0; oc < OC; oc++) { | ||||
| arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | ||||
| do_conv_kern_stride(bundle, kern_param, ncb_index, conv, | |||||
| {ncb_index.thread_id, 0, oc}); | |||||
| do_conv_kern_stride( | |||||
| bundle, kern_param, ncb_index, conv, | |||||
| {ncb_index.thread_id, 0, oc}); | |||||
| } | } | ||||
| }; | }; | ||||
| ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | ||||
| } else { | } else { | ||||
| auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||||
| const NCBKernIndex& ncb_index) mutable { | |||||
| auto copy_padding = [bundle]( | |||||
| const NCBKernParam& kern_param, | |||||
| const NCBKernIndex& ncb_index) mutable { | |||||
| bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
| arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | ||||
| copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||||
| ncb_index.ndrange_id); | |||||
| copy_padding_kern_stride( | |||||
| bundle, kern_param, ncb_index, ncb_index.ndrange_id); | |||||
| }; | }; | ||||
| ret_kerns.push_back({copy_padding, {group, N, IC}}); | ret_kerns.push_back({copy_padding, {group, N, IC}}); | ||||
| auto do_conv = [bundle, conv](const NCBKernParam& kern_param, | |||||
| const NCBKernIndex& ncb_index) mutable { | |||||
| auto do_conv = [bundle, conv]( | |||||
| const NCBKernParam& kern_param, | |||||
| const NCBKernIndex& ncb_index) mutable { | |||||
| bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
| arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | ||||
| do_conv_kern_stride(bundle, kern_param, ncb_index, conv, | |||||
| ncb_index.ndrange_id); | |||||
| do_conv_kern_stride( | |||||
| bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id); | |||||
| }; | }; | ||||
| ret_kerns.push_back({do_conv, {group, N, OC}}); | ret_kerns.push_back({do_conv, {group, N, OC}}); | ||||
| } | } | ||||
| @@ -18,13 +18,13 @@ namespace aarch64 { | |||||
| /* ===================== stride-2 algo ===================== */ | /* ===================== stride-2 algo ===================== */ | ||||
| class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { | ||||
| SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "ARMV8F16STRD2"; } | const char* name() const override { return "ARMV8F16STRD2"; } | ||||
| bool usable(const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
| bool usable( | |||||
| const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
| size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
| @@ -20,9 +20,9 @@ namespace aarch64 { | |||||
| namespace fp16 { | namespace fp16 { | ||||
| namespace conv_stride2 { | namespace conv_stride2 { | ||||
| static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter, | |||||
| __fp16* dst, size_t IH, size_t IW, size_t OH, | |||||
| size_t OW, size_t IC) { | |||||
| static void do_conv_2x2_stride2( | |||||
| const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, | |||||
| size_t OH, size_t OW, size_t IC) { | |||||
| const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
| size_t width = OW >> 3; | size_t width = OW >> 3; | ||||
| size_t mod4_left = width & 3; | size_t mod4_left = width & 3; | ||||
| @@ -162,10 +162,9 @@ static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter, | |||||
| "5: \n" | "5: \n" | ||||
| : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) | : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) | ||||
| : "r"(mod4_left), "w"(_k0123) | : "r"(mod4_left), "w"(_k0123) | ||||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||||
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
| "v15", "v16", "v17", "v18", "v19", "v28", "v29", "v30", | |||||
| "v31"); | |||||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||||
| "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
| "v17", "v18", "v19", "v28", "v29", "v30", "v31"); | |||||
| r0 += tail_step; | r0 += tail_step; | ||||
| r1 += tail_step; | r1 += tail_step; | ||||
| @@ -175,9 +174,9 @@ static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter, | |||||
| } | } | ||||
| } | } | ||||
| static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter, | |||||
| __fp16* dst, size_t IH, size_t IW, size_t OH, | |||||
| size_t OW, size_t IC) { | |||||
| static void do_conv_3x3_stride2( | |||||
| const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, | |||||
| size_t OH, size_t OW, size_t IC) { | |||||
| const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
| size_t width = OW >> 3; | size_t width = OW >> 3; | ||||
| size_t mod3_left = width % 3; | size_t mod3_left = width % 3; | ||||
| @@ -352,10 +351,10 @@ static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter, | |||||
| "3: \n" | "3: \n" | ||||
| : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) | : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) | ||||
| : "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) | : "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) | ||||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||||
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
| "v15", "v16", "v17", "v18", "v21", "v22", "v23", "v24", | |||||
| "v25", "v26", "v27", "v28", "v29"); | |||||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||||
| "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
| "v17", "v18", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
| "v28", "v29"); | |||||
| r0 += tail_step; | r0 += tail_step; | ||||
| r1 += tail_step; | r1 += tail_step; | ||||
| @@ -366,9 +365,9 @@ static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter, | |||||
| } | } | ||||
| } | } | ||||
| static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, | |||||
| __fp16* dst, size_t IH, size_t IW, size_t OH, | |||||
| size_t OW, size_t IC) { | |||||
| static void do_conv_5x5_stride2( | |||||
| const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, | |||||
| size_t OH, size_t OW, size_t IC) { | |||||
| const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
| size_t width = OW >> 3; | size_t width = OW >> 3; | ||||
| size_t mod2_left = width & 1; | size_t mod2_left = width & 1; | ||||
| @@ -384,18 +383,12 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, | |||||
| const __fp16* r4 = src_ptr + IW * 4; | const __fp16* r4 = src_ptr + IW * 4; | ||||
| register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); | register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); | ||||
| register MEGDNN_SIMD_TYPE _k4567 asm("v1") = | |||||
| MEGDNN_SIMD_LOADU(filter + 4); | |||||
| register MEGDNN_SIMD_TYPE _k891011 asm("v2") = | |||||
| MEGDNN_SIMD_LOADU(filter + 8); | |||||
| register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = | |||||
| MEGDNN_SIMD_LOADU(filter + 12); | |||||
| register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = | |||||
| MEGDNN_SIMD_LOADU(filter + 16); | |||||
| register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = | |||||
| MEGDNN_SIMD_LOADU(filter + 20); | |||||
| register MEGDNN_SIMD_TYPE _k24242424 asm("v6") = | |||||
| MEGDNN_SIMD_SET1(filter[24]); | |||||
| register MEGDNN_SIMD_TYPE _k4567 asm("v1") = MEGDNN_SIMD_LOADU(filter + 4); | |||||
| register MEGDNN_SIMD_TYPE _k891011 asm("v2") = MEGDNN_SIMD_LOADU(filter + 8); | |||||
| register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = MEGDNN_SIMD_LOADU(filter + 12); | |||||
| register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = MEGDNN_SIMD_LOADU(filter + 16); | |||||
| register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = MEGDNN_SIMD_LOADU(filter + 20); | |||||
| register MEGDNN_SIMD_TYPE _k24242424 asm("v6") = MEGDNN_SIMD_SET1(filter[24]); | |||||
| for (size_t i = 0; i < OH; i++) { | for (size_t i = 0; i < OH; i++) { | ||||
| asm volatile( | asm volatile( | ||||
| @@ -592,15 +585,14 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, | |||||
| "bne 2b \n" | "bne 2b \n" | ||||
| "3: \n" | "3: \n" | ||||
| : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), | |||||
| "+r"(r3), "+r"(r4) | |||||
| : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), | |||||
| "+r"(r4) | |||||
| : "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), | : "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), | ||||
| "w"(_k16171819), "w"(_k20212223), "w"(_k24242424), | |||||
| "r"(mod2_left) | |||||
| : "cc", "memory", "x1", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
| "v28", "v29", "v30", "v31"); | |||||
| "w"(_k16171819), "w"(_k20212223), "w"(_k24242424), "r"(mod2_left) | |||||
| : "cc", "memory", "x1", "v7", "v8", "v9", "v10", "v11", "v12", | |||||
| "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
| "v31"); | |||||
| r0 += tail_step; | r0 += tail_step; | ||||
| r1 += tail_step; | r1 += tail_step; | ||||
| @@ -613,9 +605,9 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, | |||||
| } | } | ||||
| } | } | ||||
| static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter, | |||||
| __fp16* dst, size_t IH, size_t IW, size_t OH, | |||||
| size_t OW, size_t IC) { | |||||
| static void do_conv_7x7_stride2( | |||||
| const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, | |||||
| size_t OH, size_t OW, size_t IC) { | |||||
| const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
| size_t width = OW >> 3; | size_t width = OW >> 3; | ||||
| @@ -632,30 +624,20 @@ static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter, | |||||
| const __fp16* r6 = src_ptr + IW * 6; | const __fp16* r6 = src_ptr + IW * 6; | ||||
| register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); | register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); | ||||
| register MEGDNN_SIMD_TYPE _k4567 asm("v1") = | |||||
| MEGDNN_SIMD_LOADU(filter + 4); | |||||
| register MEGDNN_SIMD_TYPE _k891011 asm("v2") = | |||||
| MEGDNN_SIMD_LOADU(filter + 8); | |||||
| register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = | |||||
| MEGDNN_SIMD_LOADU(filter + 12); | |||||
| register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = | |||||
| MEGDNN_SIMD_LOADU(filter + 16); | |||||
| register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = | |||||
| MEGDNN_SIMD_LOADU(filter + 20); | |||||
| register MEGDNN_SIMD_TYPE _k24252627 asm("v6") = | |||||
| MEGDNN_SIMD_LOADU(filter + 24); | |||||
| register MEGDNN_SIMD_TYPE _k28293031 asm("v7") = | |||||
| MEGDNN_SIMD_LOADU(filter + 28); | |||||
| register MEGDNN_SIMD_TYPE _k32333435 asm("v8") = | |||||
| MEGDNN_SIMD_LOADU(filter + 32); | |||||
| register MEGDNN_SIMD_TYPE _k36373839 asm("v9") = | |||||
| MEGDNN_SIMD_LOADU(filter + 36); | |||||
| register MEGDNN_SIMD_TYPE _k4567 asm("v1") = MEGDNN_SIMD_LOADU(filter + 4); | |||||
| register MEGDNN_SIMD_TYPE _k891011 asm("v2") = MEGDNN_SIMD_LOADU(filter + 8); | |||||
| register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = MEGDNN_SIMD_LOADU(filter + 12); | |||||
| register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = MEGDNN_SIMD_LOADU(filter + 16); | |||||
| register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = MEGDNN_SIMD_LOADU(filter + 20); | |||||
| register MEGDNN_SIMD_TYPE _k24252627 asm("v6") = MEGDNN_SIMD_LOADU(filter + 24); | |||||
| register MEGDNN_SIMD_TYPE _k28293031 asm("v7") = MEGDNN_SIMD_LOADU(filter + 28); | |||||
| register MEGDNN_SIMD_TYPE _k32333435 asm("v8") = MEGDNN_SIMD_LOADU(filter + 32); | |||||
| register MEGDNN_SIMD_TYPE _k36373839 asm("v9") = MEGDNN_SIMD_LOADU(filter + 36); | |||||
| register MEGDNN_SIMD_TYPE _k40414243 asm("v10") = | register MEGDNN_SIMD_TYPE _k40414243 asm("v10") = | ||||
| MEGDNN_SIMD_LOADU(filter + 40); | MEGDNN_SIMD_LOADU(filter + 40); | ||||
| register MEGDNN_SIMD_TYPE _k44454647 asm("v11") = | register MEGDNN_SIMD_TYPE _k44454647 asm("v11") = | ||||
| MEGDNN_SIMD_LOADU(filter + 44); | MEGDNN_SIMD_LOADU(filter + 44); | ||||
| register MEGDNN_SIMD_TYPE _k48484848 asm("v12") = | |||||
| MEGDNN_SIMD_SET1(filter[48]); | |||||
| register MEGDNN_SIMD_TYPE _k48484848 asm("v12") = MEGDNN_SIMD_SET1(filter[48]); | |||||
| for (size_t i = 0; i < OH; i++) { | for (size_t i = 0; i < OH; i++) { | ||||
| asm volatile( | asm volatile( | ||||
| @@ -1005,16 +987,15 @@ static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter, | |||||
| "bne 2b \n" | "bne 2b \n" | ||||
| "3: \n" | "3: \n" | ||||
| : "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), | |||||
| "+r"(r4), "+r"(r5), "+r"(r6) | |||||
| : "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), "+r"(r4), | |||||
| "+r"(r5), "+r"(r6) | |||||
| : "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), | : "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), | ||||
| "w"(_k12131415), "w"(_k16171819), "w"(_k20212223), | "w"(_k12131415), "w"(_k16171819), "w"(_k20212223), | ||||
| "w"(_k24252627), "w"(_k28293031), "w"(_k32333435), | "w"(_k24252627), "w"(_k28293031), "w"(_k32333435), | ||||
| "w"(_k36373839), "w"(_k40414243), "w"(_k44454647), | |||||
| "w"(_k48484848) | |||||
| : "cc", "memory", "x1", "v13", "v14", "v15", "v16", "v17", | |||||
| "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |||||
| "v26", "v27", "v28", "v29", "v30", "v31"); | |||||
| "w"(_k36373839), "w"(_k40414243), "w"(_k44454647), "w"(_k48484848) | |||||
| : "cc", "memory", "x1", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
| "v28", "v29", "v30", "v31"); | |||||
| r0 += tail_step; | r0 += tail_step; | ||||
| r1 += tail_step; | r1 += tail_step; | ||||
| @@ -21,18 +21,17 @@ using namespace megdnn; | |||||
| using namespace aarch64; | using namespace aarch64; | ||||
| MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32) | MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32) | ||||
| bool ConvBiasImpl::AlgoF32DirectStride2::usable(const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy) const { | |||||
| bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||||
| const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||||
| MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) { | MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) { | ||||
| auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
| auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
| return param.filter_meta.format == param::ConvBias::Format::NCHW && | return param.filter_meta.format == param::ConvBias::Format::NCHW && | ||||
| param.src_type.enumv() == DTypeEnum::Float32 && | param.src_type.enumv() == DTypeEnum::Float32 && | ||||
| param.filter_type.enumv() == DTypeEnum::Float32 && | param.filter_type.enumv() == DTypeEnum::Float32 && | ||||
| param.dst_type.enumv() == DTypeEnum::Float32 && | |||||
| !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||||
| fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||||
| FH == fm.spatial[1] && | |||||
| param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip && | |||||
| fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||||
| fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||||
| (FH == 2 || FH == 3 || FH == 5 || FH == 7); | (FH == 2 || FH == 3 || FH == 5 || FH == 7); | ||||
| } | } | ||||
| MIDOUT_END(); | MIDOUT_END(); | ||||
| @@ -50,8 +49,7 @@ size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( | |||||
| MIDOUT_END(); | MIDOUT_END(); | ||||
| return 0; | return 0; | ||||
| } | } | ||||
| SmallVector<ConvBiasImpl::NCBKern> | |||||
| ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | |||||
| SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | |||||
| const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
| MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) { | MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) { | ||||
| return get_kimpls(param); | return get_kimpls(param); | ||||
| @@ -60,8 +58,7 @@ ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | |||||
| return {}; | return {}; | ||||
| } | } | ||||
| SmallVector<ConvBiasImpl::NCBKern> | |||||
| ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||||
| SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||||
| const NCBKernSizeParam& param) const { | const NCBKernSizeParam& param) const { | ||||
| auto fm = param.filter_meta; | auto fm = param.filter_meta; | ||||
| auto FH = fm.spatial[0]; | auto FH = fm.spatial[0]; | ||||
| @@ -70,8 +67,9 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||||
| size_t OC = param.filter_meta.ocpg; | size_t OC = param.filter_meta.ocpg; | ||||
| size_t group = fm.group; | size_t group = fm.group; | ||||
| bool large_group = group >= param.nr_threads; | bool large_group = group >= param.nr_threads; | ||||
| using Func = std::function<void(const float*, const float*, float*, size_t, | |||||
| size_t, size_t, size_t, size_t)>; | |||||
| using Func = std::function<void( | |||||
| const float*, const float*, float*, size_t, size_t, size_t, size_t, | |||||
| size_t)>; | |||||
| Func conv = nullptr; | Func conv = nullptr; | ||||
| if (FH == 2) { | if (FH == 2) { | ||||
| conv = fp32::conv_stride2::do_conv_2x2_stride2; | conv = fp32::conv_stride2::do_conv_2x2_stride2; | ||||
| @@ -83,8 +81,9 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||||
| conv = fp32::conv_stride2::do_conv_7x7_stride2; | conv = fp32::conv_stride2::do_conv_7x7_stride2; | ||||
| } | } | ||||
| WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon< | |||||
| float, float>::get_bundle_stride(param, large_group); | |||||
| WorkspaceBundle bundle = | |||||
| arm_common::MultithreadDirectConvCommon<float, float>::get_bundle_stride( | |||||
| param, large_group); | |||||
| SmallVector<NCBKern> ret_kerns; | SmallVector<NCBKern> ret_kerns; | ||||
| //! Dense conv and small group | //! Dense conv and small group | ||||
| @@ -99,34 +98,34 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||||
| bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
| for (size_t ic = 0; ic < IC; ic++) { | for (size_t ic = 0; ic < IC; ic++) { | ||||
| arm_common::MultithreadDirectConvCommon<float, float>:: | arm_common::MultithreadDirectConvCommon<float, float>:: | ||||
| copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||||
| {ncb_index.thread_id, 0, ic}); | |||||
| copy_padding_kern_stride( | |||||
| bundle, kern_param, ncb_index, | |||||
| {ncb_index.thread_id, 0, ic}); | |||||
| } | } | ||||
| for (size_t oc = 0; oc < OC; oc++) { | for (size_t oc = 0; oc < OC; oc++) { | ||||
| arm_common::MultithreadDirectConvCommon< | |||||
| float, float>::do_conv_kern_stride(bundle, kern_param, | |||||
| ncb_index, conv, | |||||
| {ncb_index.thread_id, | |||||
| 0, oc}); | |||||
| arm_common::MultithreadDirectConvCommon<float, float>:: | |||||
| do_conv_kern_stride( | |||||
| bundle, kern_param, ncb_index, conv, | |||||
| {ncb_index.thread_id, 0, oc}); | |||||
| } | } | ||||
| }; | }; | ||||
| ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | ||||
| } else { | } else { | ||||
| auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||||
| const NCBKernIndex& ncb_index) mutable { | |||||
| auto copy_padding = [bundle]( | |||||
| const NCBKernParam& kern_param, | |||||
| const NCBKernIndex& ncb_index) mutable { | |||||
| bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
| arm_common::MultithreadDirectConvCommon<float, float>:: | arm_common::MultithreadDirectConvCommon<float, float>:: | ||||
| copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||||
| ncb_index.ndrange_id); | |||||
| copy_padding_kern_stride( | |||||
| bundle, kern_param, ncb_index, ncb_index.ndrange_id); | |||||
| }; | }; | ||||
| ret_kerns.push_back({copy_padding, {group, N, IC}}); | ret_kerns.push_back({copy_padding, {group, N, IC}}); | ||||
| auto do_conv = [bundle, conv](const NCBKernParam& kern_param, | |||||
| const NCBKernIndex& ncb_index) mutable { | |||||
| auto do_conv = [bundle, conv]( | |||||
| const NCBKernParam& kern_param, | |||||
| const NCBKernIndex& ncb_index) mutable { | |||||
| bundle.set(kern_param.workspace_ptr); | bundle.set(kern_param.workspace_ptr); | ||||
| arm_common::MultithreadDirectConvCommon< | |||||
| float, float>::do_conv_kern_stride(bundle, kern_param, | |||||
| ncb_index, conv, | |||||
| ncb_index.ndrange_id); | |||||
| arm_common::MultithreadDirectConvCommon<float, float>::do_conv_kern_stride( | |||||
| bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id); | |||||
| }; | }; | ||||
| ret_kerns.push_back({do_conv, {group, N, OC}}); | ret_kerns.push_back({do_conv, {group, N, OC}}); | ||||
| } | } | ||||
| @@ -22,14 +22,14 @@ using FallbackConvBiasImpl = fallback::ConvBiasImpl; | |||||
| class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | ||||
| SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "ARMV8F32STRD2"; } | const char* name() const override { return "ARMV8F32STRD2"; } | ||||
| bool usable(const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
| bool usable( | |||||
| const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
| size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
| @@ -16,16 +16,15 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace aarch64 { | namespace aarch64 { | ||||
| namespace fp32{ | |||||
| namespace fp32 { | |||||
| namespace conv_stride2 { | namespace conv_stride2 { | ||||
| //! For the detail tune process, refer to `expr/conv_aarch64_stride2/main.cpp` | //! For the detail tune process, refer to `expr/conv_aarch64_stride2/main.cpp` | ||||
| // refer to function do_conv_2x2_stride2_asm_unroll4 | // refer to function do_conv_2x2_stride2_asm_unroll4 | ||||
| static void do_conv_2x2_stride2(const float* src, const float* filter, | |||||
| float* dst, size_t IH, size_t IW, size_t OH, | |||||
| size_t OW, size_t IC) { | |||||
| static void do_conv_2x2_stride2( | |||||
| const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
| size_t OH, size_t OW, size_t IC) { | |||||
| const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
| size_t width = OW >> 2; | size_t width = OW >> 2; | ||||
| size_t mod4_left = width & 3; | size_t mod4_left = width & 3; | ||||
| @@ -165,10 +164,9 @@ static void do_conv_2x2_stride2(const float* src, const float* filter, | |||||
| "5: \n" | "5: \n" | ||||
| : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) | : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) | ||||
| : "r"(mod4_left), "w"(_k0123) | : "r"(mod4_left), "w"(_k0123) | ||||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||||
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
| "v15", "v16", "v17", "v18", "v19", "v28", "v29", "v30", | |||||
| "v31"); | |||||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||||
| "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
| "v17", "v18", "v19", "v28", "v29", "v30", "v31"); | |||||
| r0 += tail_step; | r0 += tail_step; | ||||
| r1 += tail_step; | r1 += tail_step; | ||||
| @@ -179,9 +177,9 @@ static void do_conv_2x2_stride2(const float* src, const float* filter, | |||||
| } | } | ||||
| // refer to function do_conv_3x3_stride2_asm_unroll3 | // refer to function do_conv_3x3_stride2_asm_unroll3 | ||||
| static void do_conv_3x3_stride2(const float* src, const float* filter, | |||||
| float* dst, size_t IH, size_t IW, size_t OH, | |||||
| size_t OW, size_t IC) { | |||||
| static void do_conv_3x3_stride2( | |||||
| const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
| size_t OH, size_t OW, size_t IC) { | |||||
| const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
| size_t width = OW >> 2; | size_t width = OW >> 2; | ||||
| size_t mod3_left = width % 3; | size_t mod3_left = width % 3; | ||||
| @@ -269,7 +267,7 @@ static void do_conv_3x3_stride2(const float* src, const float* filter, | |||||
| "ld2 {v1.4s, v2.4s}, [%2], #32 \n" // 0, 2, 4, 6 | "ld2 {v1.4s, v2.4s}, [%2], #32 \n" // 0, 2, 4, 6 | ||||
| "ld2 {v5.4s, v6.4s}, [%3], #32 \n" | "ld2 {v5.4s, v6.4s}, [%3], #32 \n" | ||||
| "ld1 {v3.4s}, [%2] \n" // load src 8 12 ... | |||||
| "ld1 {v3.4s}, [%2] \n" // load src 8 12 ... | |||||
| "fmla v0.4s, v1.4s, v21.4s \n" // src[i] * k[i] | "fmla v0.4s, v1.4s, v21.4s \n" // src[i] * k[i] | ||||
| "ext v7.16b, v1.16b, v3.16b, #4 \n" // 2, 4, 6, 8 | "ext v7.16b, v1.16b, v3.16b, #4 \n" // 2, 4, 6, 8 | ||||
| "fmla v0.4s, v2.4s, v22.4s \n" | "fmla v0.4s, v2.4s, v22.4s \n" | ||||
| @@ -356,10 +354,10 @@ static void do_conv_3x3_stride2(const float* src, const float* filter, | |||||
| "3: \n" | "3: \n" | ||||
| : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) | : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) | ||||
| : "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) | : "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) | ||||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||||
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
| "v15", "v16", "v17", "v18", "v21", "v22", "v23", "v24", | |||||
| "v25", "v26", "v27", "v28", "v29"); | |||||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||||
| "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
| "v17", "v18", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
| "v28", "v29"); | |||||
| r0 += tail_step; | r0 += tail_step; | ||||
| r1 += tail_step; | r1 += tail_step; | ||||
| @@ -371,9 +369,9 @@ static void do_conv_3x3_stride2(const float* src, const float* filter, | |||||
| } | } | ||||
| // refer to function do_conv_5x5_stride2_asm_unroll2 | // refer to function do_conv_5x5_stride2_asm_unroll2 | ||||
| static void do_conv_5x5_stride2(const float* src, const float* filter, | |||||
| float* dst, size_t IH, size_t IW, size_t OH, | |||||
| size_t OW, size_t IC) { | |||||
| static void do_conv_5x5_stride2( | |||||
| const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
| size_t OH, size_t OW, size_t IC) { | |||||
| const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
| size_t width = OW >> 2; | size_t width = OW >> 2; | ||||
| size_t mod2_left = width & 1; | size_t mod2_left = width & 1; | ||||
| @@ -591,15 +589,13 @@ static void do_conv_5x5_stride2(const float* src, const float* filter, | |||||
| "bne 2b \n" | "bne 2b \n" | ||||
| "3: \n" | "3: \n" | ||||
| : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), | |||||
| "+r"(r3), "+r"(r4) | |||||
| : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), | |||||
| "+r"(r4) | |||||
| : "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), | : "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), | ||||
| "w"(_k16171819), "w"(_k20212223), "w"(_k24242424), | |||||
| "r"(mod2_left) | |||||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||||
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||||
| "v23", "v24"); | |||||
| "w"(_k16171819), "w"(_k20212223), "w"(_k24242424), "r"(mod2_left) | |||||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||||
| "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
| "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24"); | |||||
| r0 += tail_step; | r0 += tail_step; | ||||
| r1 += tail_step; | r1 += tail_step; | ||||
| @@ -613,9 +609,9 @@ static void do_conv_5x5_stride2(const float* src, const float* filter, | |||||
| } | } | ||||
| // refer to function do_conv_7x7_stride2_asm_unroll2 | // refer to function do_conv_7x7_stride2_asm_unroll2 | ||||
| static void do_conv_7x7_stride2(const float* src, const float* filter, | |||||
| float* dst, size_t IH, size_t IW, size_t OH, | |||||
| size_t OW, size_t IC) { | |||||
| static void do_conv_7x7_stride2( | |||||
| const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||||
| size_t OH, size_t OW, size_t IC) { | |||||
| const size_t tail_step = IW - 2 * OW + IW; | const size_t tail_step = IW - 2 * OW + IW; | ||||
| size_t width = OW >> 2; | size_t width = OW >> 2; | ||||
| @@ -993,16 +989,15 @@ static void do_conv_7x7_stride2(const float* src, const float* filter, | |||||
| "bne 2b \n" | "bne 2b \n" | ||||
| "3: \n" | "3: \n" | ||||
| : "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), | |||||
| "+r"(r4), "+r"(r5), "+r"(r6) | |||||
| : "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), "+r"(r4), | |||||
| "+r"(r5), "+r"(r6) | |||||
| : "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), | : "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), | ||||
| "w"(_k12131415), "w"(_k16171819), "w"(_k20212223), | "w"(_k12131415), "w"(_k16171819), "w"(_k20212223), | ||||
| "w"(_k24252627), "w"(_k28293031), "w"(_k32333435), | "w"(_k24252627), "w"(_k28293031), "w"(_k32333435), | ||||
| "w"(_k36373839), "w"(_k40414243), "w"(_k44454647), | |||||
| "w"(_k48484848) | |||||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||||
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
| "v15", "v16", "v17", "v18"); | |||||
| "w"(_k36373839), "w"(_k40414243), "w"(_k44454647), "w"(_k48484848) | |||||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||||
| "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
| "v17", "v18"); | |||||
| r0 += tail_step; | r0 += tail_step; | ||||
| r1 += tail_step; | r1 += tail_step; | ||||
| @@ -68,9 +68,9 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( | |||||
| size_t N = OH * OW; | size_t N = OH * OW; | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
| _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| #define DISPATCH_GEMM_STRATEGY( \ | |||||
| _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | ||||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | M, N, K, param.filter_type, param.src_type, param.dst_type); \ | ||||
| part2 = megdnn::matmul::GemmInterleaved< \ | part2 = megdnn::matmul::GemmInterleaved< \ | ||||
| @@ -84,11 +84,12 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( | |||||
| DISPATCH_GEMM_BIAS(s8_4x4, 0) | DISPATCH_GEMM_BIAS(s8_4x4, 0) | ||||
| } | } | ||||
| #else | #else | ||||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
| _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \ | |||||
| _bias_midout_enum, _nonline_midout_enum) { \ | |||||
| #define DISPATCH_GEMM_STRATEGY( \ | |||||
| _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| MIDOUT_BEGIN( \ | |||||
| megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \ | |||||
| _bias_midout_enum, _nonline_midout_enum) { \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | ||||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | M, N, K, param.filter_type, param.src_type, param.dst_type); \ | ||||
| part2 = megdnn::matmul::GemmInterleaved< \ | part2 = megdnn::matmul::GemmInterleaved< \ | ||||
| @@ -104,8 +105,8 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( | |||||
| return {nullptr, {part0, part1, part2}}; | return {nullptr, {part0, part1, part2}}; | ||||
| } | } | ||||
| void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, | |||||
| const NCBKernIndex& ncb_index) { | |||||
| void ConvBiasImpl::AlgoS8MatrixMul::kimpl( | |||||
| const NCBKernParam& param, const NCBKernIndex& ncb_index) { | |||||
| auto is_xcorr = !param.filter_meta.should_flip; | auto is_xcorr = !param.filter_meta.should_flip; | ||||
| UNPACK_CONV_NCB_KERN_SIZES(param); | UNPACK_CONV_NCB_KERN_SIZES(param); | ||||
| auto bundle = get_bundle(param); | auto bundle = get_bundle(param); | ||||
| @@ -157,29 +158,28 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, | |||||
| img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | ||||
| } else { | } else { | ||||
| if (is_xcorr) | if (is_xcorr) | ||||
| img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||||
| FW, SH, SW); | |||||
| img2col_stride<true>( | |||||
| src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); | |||||
| else | else | ||||
| img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||||
| FW, SH, SW); | |||||
| img2col_stride<false>( | |||||
| src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); | |||||
| } | } | ||||
| } | } | ||||
| { | { | ||||
| Workspace workspace(static_cast<dt_byte*>(bundle.get(2)), | |||||
| bundle.get_size(2)); | |||||
| Workspace workspace( | |||||
| static_cast<dt_byte*>(bundle.get(2)), bundle.get_size(2)); | |||||
| size_t M = OC; | size_t M = OC; | ||||
| size_t K = IC * FH * FW; | size_t K = IC * FH * FW; | ||||
| size_t N = OH * OW; | size_t N = OH * OW; | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
| _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
| megdnn::matmul::GemmInterleaved< \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
| #define DISPATCH_GEMM_STRATEGY( \ | |||||
| _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
| megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
| gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); | gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); | ||||
| if (cpuinfo_has_arm_neon_dot()) { | if (cpuinfo_has_arm_neon_dot()) { | ||||
| @@ -188,19 +188,18 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, | |||||
| DISPATCH_GEMM_BIAS(s8_4x4, 0) | DISPATCH_GEMM_BIAS(s8_4x4, 0) | ||||
| } | } | ||||
| #else | #else | ||||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
| _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \ | |||||
| _bias_midout_enum, _nonline_midout_enum) { \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
| megdnn::matmul::GemmInterleaved< \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
| gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ | |||||
| bias); \ | |||||
| } \ | |||||
| #define DISPATCH_GEMM_STRATEGY( \ | |||||
| _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| MIDOUT_BEGIN( \ | |||||
| megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \ | |||||
| _bias_midout_enum, _nonline_midout_enum) { \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
| megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
| gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \ | |||||
| } \ | |||||
| MIDOUT_END() | MIDOUT_END() | ||||
| DISPATCH_GEMM_BIAS(s8_4x4, 0) | DISPATCH_GEMM_BIAS(s8_4x4, 0) | ||||
| #endif | #endif | ||||
| @@ -12,8 +12,8 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "src/aarch64/conv_bias/opr_impl.h" | #include "src/aarch64/conv_bias/opr_impl.h" | ||||
| #include "src/fallback/conv_bias/opr_impl.h" | |||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| #include "src/fallback/conv_bias/opr_impl.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace aarch64 { | namespace aarch64 { | ||||
| @@ -25,18 +25,16 @@ class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase { | |||||
| static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index); | static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index); | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "S8MATMUL"; } | const char* name() const override { return "S8MATMUL"; } | ||||
| bool usable(const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
| bool usable( | |||||
| const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
| size_t get_workspace(const NCBKernSizeParam& param) const override { | size_t get_workspace(const NCBKernSizeParam& param) const override { | ||||
| return get_bundle(param).total_size_in_bytes(); | return get_bundle(param).total_size_in_bytes(); | ||||
| } | } | ||||
| SmallVector<NCBKern> dispatch_kerns( | |||||
| const NCBKernSizeParam& param) const override { | |||||
| SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam& param) const override { | |||||
| size_t group = param.filter_meta.group; | size_t group = param.filter_meta.group; | ||||
| return {{kimpl, {group, 1_z, 1_z}}}; | return {{kimpl, {group, 1_z, 1_z}}}; | ||||
| } | } | ||||
| @@ -29,9 +29,10 @@ struct KernCaller; | |||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| template <BiasMode bmode, typename Op> | template <BiasMode bmode, typename Op> | ||||
| struct KernCaller<bmode, Op, 8, 12> { | struct KernCaller<bmode, Op, 8, 12> { | ||||
| static void run(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k, | |||||
| Op op, const dt_int32* bias, dt_int32* workspace) { | |||||
| static void run( | |||||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
| dt_int8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, | |||||
| dt_int32* workspace) { | |||||
| megdnn_assert(is_first_k); | megdnn_assert(is_first_k); | ||||
| constexpr size_t A_INTERLEAVE = 8; | constexpr size_t A_INTERLEAVE = 8; | ||||
| @@ -49,19 +50,19 @@ struct KernCaller<bmode, Op, 8, 12> { | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_8x12x4::kern_8x12(packA, cur_packB, K, workspace, 12, | |||||
| is_first_k); | |||||
| matmul_8x12x4::kern_8x12( | |||||
| packA, cur_packB, K, workspace, 12, is_first_k); | |||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 12, 8, | |||||
| 12>::postprocess(bias, workspace, | |||||
| output, LDC, op); | |||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 12, 8, 12>:: | |||||
| postprocess(bias, workspace, output, LDC, op); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K12; | cur_packB += K12; | ||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_8x12x4::kern_8x4(packA, cur_packB, K, workspace, 4, | |||||
| is_first_k, std::min<size_t>(N - n, 4)); | |||||
| matmul_8x12x4::kern_8x4( | |||||
| packA, cur_packB, K, workspace, 4, is_first_k, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| #define cb(m, n) \ | #define cb(m, n) \ | ||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 4, 8, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 4, 8, n>::postprocess( \ | ||||
| @@ -83,9 +84,9 @@ struct KernCaller<bmode, Op, 8, 12> { | |||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| size_t n = 0; | size_t n = 0; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_8x12x4::kern_4x12(packA, cur_packB, K, workspace, 12, | |||||
| is_first_k, | |||||
| std::min<size_t>(M - m, 4)); | |||||
| matmul_8x12x4::kern_4x12( | |||||
| packA, cur_packB, K, workspace, 12, is_first_k, | |||||
| std::min<size_t>(M - m, 4)); | |||||
| #define cb(m, n) \ | #define cb(m, n) \ | ||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 12, m, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 12, m, n>::postprocess( \ | ||||
| bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
| @@ -97,14 +98,13 @@ struct KernCaller<bmode, Op, 8, 12> { | |||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_8x12x4::kern_4x4(packA, cur_packB, K, workspace, 4, | |||||
| is_first_k, std::min<size_t>(M - m, 4), | |||||
| std::min<size_t>(N - n, 4)); | |||||
| matmul_8x12x4::kern_4x4( | |||||
| packA, cur_packB, K, workspace, 4, is_first_k, | |||||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
| #define cb(m, n) \ | #define cb(m, n) \ | ||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \ | ||||
| bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
| DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||||
| std::min<size_t>(N - n, 4)); | |||||
| DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
| #undef cb | #undef cb | ||||
| output += 4; | output += 4; | ||||
| @@ -122,9 +122,10 @@ struct KernCaller<bmode, Op, 8, 12> { | |||||
| template <BiasMode bmode, typename Op> | template <BiasMode bmode, typename Op> | ||||
| struct KernCaller<bmode, Op, 4, 4> { | struct KernCaller<bmode, Op, 4, 4> { | ||||
| static void run(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k, | |||||
| Op op, const dt_int32* bias, dt_int32* workspace) { | |||||
| static void run( | |||||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
| dt_int8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, | |||||
| dt_int32* workspace) { | |||||
| megdnn_assert(is_first_k); | megdnn_assert(is_first_k); | ||||
| constexpr size_t A_INTERLEAVE = 4; | constexpr size_t A_INTERLEAVE = 4; | ||||
| @@ -140,20 +141,18 @@ struct KernCaller<bmode, Op, 4, 4> { | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4, | |||||
| is_first_k); | |||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, | |||||
| 4>::postprocess(bias, workspace, | |||||
| output, LDC, op); | |||||
| matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4, is_first_k); | |||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, 4>::postprocess( | |||||
| bias, workspace, output, LDC, op); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| for (; n < N; n += B_INTERLEAVE) { | for (; n < N; n += B_INTERLEAVE) { | ||||
| matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, workspace, | |||||
| 4, is_first_k, 4, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| matmul_4x4x16::kern_4x4_remain( | |||||
| packA, cur_packB, K, workspace, 4, is_first_k, 4, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| #define cb(m, n) \ | #define cb(m, n) \ | ||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, n>::postprocess( \ | ||||
| bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
| @@ -182,8 +181,7 @@ struct KernCaller<bmode, Op, 4, 4> { | |||||
| #define cb(m, n) \ | #define cb(m, n) \ | ||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \ | ||||
| bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
| DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||||
| std::min<size_t>(N - n, 4)); | |||||
| DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
| #undef cb | #undef cb | ||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| @@ -200,21 +198,19 @@ struct KernCaller<bmode, Op, 4, 4> { | |||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity) | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity) | ||||
| void gemm_s8_4x4_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
| int ldin, int y0, int ymax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| void gemm_s8_4x4_nobias_identity::pack_A( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void gemm_s8_4x4_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, | |||||
| int ldin, int x0, int xmax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| void gemm_s8_4x4_nobias_identity::pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | ||||
| } else { | } else { | ||||
| @@ -229,23 +225,21 @@ size_t gemm_s8_4x4_nobias_identity::get_workspace_size() const { | |||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity) | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity) | ||||
| void gemm_s8_8x12_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
| int ldin, int y0, int ymax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| void gemm_s8_8x12_nobias_identity::pack_A( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_A_t); | MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_A_t); | ||||
| MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_B_t); | MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_B_t); | ||||
| if (transpose) { | if (transpose) { | ||||
| matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void gemm_s8_8x12_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, | |||||
| int ldin, int x0, int xmax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| void gemm_s8_8x12_nobias_identity::pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_8x12x4::gemm_s8_8x12_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | matmul_8x12x4::gemm_s8_8x12_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | ||||
| } else { | } else { | ||||
| @@ -259,18 +253,17 @@ size_t gemm_s8_8x12_nobias_identity::get_workspace_size() const { | |||||
| #endif | #endif | ||||
| #define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ | |||||
| void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ | |||||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, \ | |||||
| size_t K, dt_int8* C, size_t LDC, bool is_first_k, \ | |||||
| const dt_int32* bias, dt_int32* workspace) const { \ | |||||
| float scale_A = A_dtype.param<dtype::QuantizedS8>().scale; \ | |||||
| float scale_B = B_dtype.param<dtype::QuantizedS8>().scale; \ | |||||
| float scale_C = C_dtype.param<dtype::QuantizedS8>().scale; \ | |||||
| DEFINE_OP(_OP); \ | |||||
| impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ | |||||
| packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ | |||||
| workspace); \ | |||||
| #define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ | |||||
| void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ | |||||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, \ | |||||
| dt_int8* C, size_t LDC, bool is_first_k, const dt_int32* bias, \ | |||||
| dt_int32* workspace) const { \ | |||||
| float scale_A = A_dtype.param<dtype::QuantizedS8>().scale; \ | |||||
| float scale_B = B_dtype.param<dtype::QuantizedS8>().scale; \ | |||||
| float scale_C = C_dtype.param<dtype::QuantizedS8>().scale; \ | |||||
| DEFINE_OP(_OP); \ | |||||
| impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ | |||||
| packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace); \ | |||||
| } | } | ||||
| #define DEFINE_OP(_Op) \ | #define DEFINE_OP(_Op) \ | ||||
| @@ -286,18 +279,16 @@ KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | |||||
| #endif | #endif | ||||
| #undef DEFINE_OP | #undef DEFINE_OP | ||||
| #define DEFINE_OP(_Op) \ | |||||
| arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, \ | |||||
| scale_A* scale_B, scale_C); | |||||
| #define DEFINE_OP(_Op) \ | |||||
| arm_common::_Op<dt_qint32, dt_qint8> op( \ | |||||
| scale_A* scale_B, scale_A* scale_B, scale_C); | |||||
| KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | ||||
| KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | ||||
| KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||||
| FuseAddHSwishOp) | |||||
| KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | ||||
| KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | ||||
| KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||||
| FuseAddHSwishOp) | |||||
| KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||||
| #endif | #endif | ||||
| #undef DEFINE_OP | #undef DEFINE_OP | ||||
| @@ -20,43 +20,42 @@ namespace matmul { | |||||
| * | * | ||||
| * \name gemm_<type>_<block>_biasmode_nolinemode | * \name gemm_<type>_<block>_biasmode_nolinemode | ||||
| */ | */ | ||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 4, 4, 16, | |||||
| false, true, | |||||
| gemm_s8_4x4_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( | |||||
| dt_int8, dt_int8, dt_int32, 4, 4, 16, false, true, gemm_s8_4x4_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_relu, | |||||
| gemm_s8_4x4_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_s8_4x4_nobias_relu, gemm_s8_4x4_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_hswish, | |||||
| gemm_s8_4x4_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_s8_4x4_nobias_hswish, gemm_s8_4x4_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_identity, | |||||
| gemm_s8_4x4_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_s8_4x4_bias_channel_identity, gemm_s8_4x4_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_relu, | |||||
| gemm_s8_4x4_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_s8_4x4_bias_channel_relu, gemm_s8_4x4_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_hswish, | |||||
| gemm_s8_4x4_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_s8_4x4_bias_channel_hswish, gemm_s8_4x4_nobias_identity); | |||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 8, 12, 4, | |||||
| false, true, | |||||
| gemm_s8_8x12_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( | |||||
| dt_int8, dt_int8, dt_int32, 8, 12, 4, false, true, | |||||
| gemm_s8_8x12_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_relu, | |||||
| gemm_s8_8x12_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_s8_8x12_nobias_relu, gemm_s8_8x12_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_hswish, | |||||
| gemm_s8_8x12_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_s8_8x12_nobias_hswish, gemm_s8_8x12_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_identity, | |||||
| gemm_s8_8x12_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_s8_8x12_bias_channel_identity, gemm_s8_8x12_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_relu, | |||||
| gemm_s8_8x12_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_s8_8x12_bias_channel_relu, gemm_s8_8x12_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_hswish, | |||||
| gemm_s8_8x12_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_s8_8x12_bias_channel_hswish, gemm_s8_8x12_nobias_identity); | |||||
| #endif | #endif | ||||
| } // namespace matmul | } // namespace matmul | ||||
| @@ -13,13 +13,13 @@ | |||||
| #include "src/aarch64/conv_bias/int8/algos.h" | #include "src/aarch64/conv_bias/int8/algos.h" | ||||
| #include "src/aarch64/conv_bias/quint8/algos.h" | #include "src/aarch64/conv_bias/quint8/algos.h" | ||||
| #include "src/naive/handle.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/common/metahelper.h" | #include "src/common/metahelper.h" | ||||
| #include "src/common/utils.h" | |||||
| #include "src/naive/handle.h" | |||||
| #include "src/fallback/convolution/opr_impl.h" | |||||
| #include "src/aarch64/conv_bias/fp32/algos.h" | |||||
| #include "src/aarch64/conv_bias/fp16/algos.h" | #include "src/aarch64/conv_bias/fp16/algos.h" | ||||
| #include "src/aarch64/conv_bias/fp32/algos.h" | |||||
| #include "src/fallback/convolution/opr_impl.h" | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace aarch64; | using namespace aarch64; | ||||
| @@ -56,12 +56,10 @@ public: | |||||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() const { | const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() const { | ||||
| return m_direct_algos; | return m_direct_algos; | ||||
| } | } | ||||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos() | |||||
| const { | |||||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos() const { | |||||
| return m_matmul_algos; | return m_matmul_algos; | ||||
| } | } | ||||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | ||||
| }; | }; | ||||
| const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | ||||
| @@ -71,15 +69,16 @@ const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||||
| MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) | MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) | ||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> | |||||
| ConvBiasImpl::get_all_packed_algo() { | |||||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::get_all_packed_algo() { | |||||
| auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); | auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); | ||||
| algos.insert(algos.begin(), algo_pack().direct_algos().begin(), | |||||
| algo_pack().direct_algos().end()); | |||||
| algos.insert( | |||||
| algos.begin(), algo_pack().direct_algos().begin(), | |||||
| algo_pack().direct_algos().end()); | |||||
| //! We put the matmul algos at the beginning, because matmul gets priority | //! We put the matmul algos at the beginning, because matmul gets priority | ||||
| //! when prefer returns true. | //! when prefer returns true. | ||||
| algos.insert(algos.begin(), algo_pack().matmul_algos().begin(), | |||||
| algo_pack().matmul_algos().end()); | |||||
| algos.insert( | |||||
| algos.begin(), algo_pack().matmul_algos().begin(), | |||||
| algo_pack().matmul_algos().end()); | |||||
| return std::move(algos); | return std::move(algos); | ||||
| } | } | ||||
| @@ -9,8 +9,8 @@ | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "src/common/utils.h" | |||||
| #include "src/arm_common/conv_bias/opr_impl.h" | #include "src/arm_common/conv_bias/opr_impl.h" | ||||
| #include "src/common/utils.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace aarch64 { | namespace aarch64 { | ||||
| @@ -70,9 +70,9 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( | |||||
| size_t N = OH * OW; | size_t N = OH * OW; | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
| _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| #define DISPATCH_GEMM_STRATEGY( \ | |||||
| _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | ||||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | M, N, K, param.filter_type, param.src_type, param.dst_type); \ | ||||
| part2 = megdnn::matmul::GemmInterleaved< \ | part2 = megdnn::matmul::GemmInterleaved< \ | ||||
| @@ -86,11 +86,12 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( | |||||
| DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0); | DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0); | ||||
| } | } | ||||
| #else | #else | ||||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
| _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \ | |||||
| _bias_midout_enum, _nonline_midout_enum) { \ | |||||
| #define DISPATCH_GEMM_STRATEGY( \ | |||||
| _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| MIDOUT_BEGIN( \ | |||||
| megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \ | |||||
| _bias_midout_enum, _nonline_midout_enum) { \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | ||||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | M, N, K, param.filter_type, param.src_type, param.dst_type); \ | ||||
| part2 = megdnn::matmul::GemmInterleaved< \ | part2 = megdnn::matmul::GemmInterleaved< \ | ||||
| @@ -106,8 +107,8 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( | |||||
| return {nullptr, {part0, part1, part2}}; | return {nullptr, {part0, part1, part2}}; | ||||
| } | } | ||||
| void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, | |||||
| const NCBKernIndex& ncb_index) { | |||||
| void ConvBiasImpl::AlgoQU8MatrixMul::kimpl( | |||||
| const NCBKernParam& param, const NCBKernIndex& ncb_index) { | |||||
| auto is_xcorr = !param.filter_meta.should_flip; | auto is_xcorr = !param.filter_meta.should_flip; | ||||
| UNPACK_CONV_NCB_KERN_SIZES(param); | UNPACK_CONV_NCB_KERN_SIZES(param); | ||||
| auto bundle = get_bundle(param); | auto bundle = get_bundle(param); | ||||
| @@ -160,29 +161,28 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, | |||||
| img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | ||||
| } else { | } else { | ||||
| if (is_xcorr) | if (is_xcorr) | ||||
| img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||||
| FW, SH, SW); | |||||
| img2col_stride<true>( | |||||
| src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); | |||||
| else | else | ||||
| img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||||
| FW, SH, SW); | |||||
| img2col_stride<false>( | |||||
| src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); | |||||
| } | } | ||||
| } | } | ||||
| { | { | ||||
| Workspace workspace(static_cast<dt_byte*>(bundle.get(2)), | |||||
| bundle.get_size(2)); | |||||
| Workspace workspace( | |||||
| static_cast<dt_byte*>(bundle.get(2)), bundle.get_size(2)); | |||||
| size_t M = OC; | size_t M = OC; | ||||
| size_t K = IC * FH * FW; | size_t K = IC * FH * FW; | ||||
| size_t N = OH * OW; | size_t N = OH * OW; | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
| _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
| megdnn::matmul::GemmInterleaved< \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
| #define DISPATCH_GEMM_STRATEGY( \ | |||||
| _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
| megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
| gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); | gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); | ||||
| if (cpuinfo_has_arm_neon_dot()) { | if (cpuinfo_has_arm_neon_dot()) { | ||||
| @@ -191,19 +191,18 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, | |||||
| DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) | DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) | ||||
| } | } | ||||
| #else | #else | ||||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||||
| _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \ | |||||
| _bias_midout_enum, _nonline_midout_enum) { \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
| megdnn::matmul::GemmInterleaved< \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
| gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ | |||||
| bias); \ | |||||
| } \ | |||||
| #define DISPATCH_GEMM_STRATEGY( \ | |||||
| _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||||
| _nonline_midout_enum) \ | |||||
| MIDOUT_BEGIN( \ | |||||
| megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \ | |||||
| _bias_midout_enum, _nonline_midout_enum) { \ | |||||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||||
| megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||||
| gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \ | |||||
| } \ | |||||
| MIDOUT_END() | MIDOUT_END() | ||||
| DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) | DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) | ||||
| @@ -12,8 +12,8 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "src/aarch64/conv_bias/opr_impl.h" | #include "src/aarch64/conv_bias/opr_impl.h" | ||||
| #include "src/fallback/conv_bias/opr_impl.h" | |||||
| #include "src/common/opr_delegate.h" | #include "src/common/opr_delegate.h" | ||||
| #include "src/fallback/conv_bias/opr_impl.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace aarch64 { | namespace aarch64 { | ||||
| @@ -25,18 +25,16 @@ class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase { | |||||
| static void kimpl(const NCBKernParam& param, const NCBKernIndex&); | static void kimpl(const NCBKernParam& param, const NCBKernIndex&); | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "QU8MATMUL"; } | const char* name() const override { return "QU8MATMUL"; } | ||||
| bool usable(const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
| bool usable( | |||||
| const NCBKernSizeParam& param, | |||||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||||
| size_t get_workspace(const NCBKernSizeParam& param) const override { | size_t get_workspace(const NCBKernSizeParam& param) const override { | ||||
| return get_bundle(param).total_size_in_bytes(); | return get_bundle(param).total_size_in_bytes(); | ||||
| } | } | ||||
| SmallVector<NCBKern> dispatch_kerns( | |||||
| const NCBKernSizeParam& param) const override { | |||||
| SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam& param) const override { | |||||
| size_t group = param.filter_meta.group; | size_t group = param.filter_meta.group; | ||||
| return {{kimpl, {group, 1_z, 1_z}}}; | return {{kimpl, {group, 1_z, 1_z}}}; | ||||
| } | } | ||||
| @@ -14,8 +14,8 @@ | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/conv_bias/common.h" | #include "src/fallback/conv_bias/common.h" | ||||
| #include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h" | |||||
| #include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" | #include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" | ||||
| #include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h" | |||||
| #include "src/arm_common/conv_bias/matmul_postprocess.h" | #include "src/arm_common/conv_bias/matmul_postprocess.h" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| @@ -29,10 +29,10 @@ struct KernCaller; | |||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| template <BiasMode bmode, typename Op> | template <BiasMode bmode, typename Op> | ||||
| struct KernCaller<bmode, Op, 8, 8, true> { | struct KernCaller<bmode, Op, 8, 8, true> { | ||||
| static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, | |||||
| size_t N, size_t K, dt_uint8* C, size_t LDC, | |||||
| bool is_first_k, Op op, const dt_int32* bias, | |||||
| dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||||
| static void run( | |||||
| const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K, | |||||
| dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, | |||||
| dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||||
| megdnn_assert(is_first_k); | megdnn_assert(is_first_k); | ||||
| constexpr size_t A_INTERLEAVE = 8; | constexpr size_t A_INTERLEAVE = 8; | ||||
| constexpr size_t B_INTERLEAVE = 8; | constexpr size_t B_INTERLEAVE = 8; | ||||
| @@ -50,20 +50,19 @@ struct KernCaller<bmode, Op, 8, 8, true> { | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const dt_uint8* cur_packB = packB; | const dt_uint8* cur_packB = packB; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_8x8x4::kern_8x8(packA, cur_packB, K, workspace, 8, | |||||
| is_first_k, zp_A, zp_B, zAB); | |||||
| matmul_8x8x4::kern_8x8( | |||||
| packA, cur_packB, K, workspace, 8, is_first_k, zp_A, zp_B, zAB); | |||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, | |||||
| 8>::postprocess(bias, workspace, | |||||
| output, LDC, op); | |||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, 8>:: | |||||
| postprocess(bias, workspace, output, LDC, op); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K8; | cur_packB += K8; | ||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_8x8x4::kern_8x4(packA, cur_packB, K, workspace, 4, | |||||
| is_first_k, std::min<size_t>(N - n, 4), | |||||
| zp_A, zp_B, zAB); | |||||
| matmul_8x8x4::kern_8x4( | |||||
| packA, cur_packB, K, workspace, 4, is_first_k, | |||||
| std::min<size_t>(N - n, 4), zp_A, zp_B, zAB); | |||||
| #define cb(m, n) \ | #define cb(m, n) \ | ||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \ | ||||
| bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
| @@ -84,9 +83,9 @@ struct KernCaller<bmode, Op, 8, 8, true> { | |||||
| const dt_uint8* cur_packB = packB; | const dt_uint8* cur_packB = packB; | ||||
| size_t n = 0; | size_t n = 0; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_8x8x4::kern_4x8(packA, cur_packB, K, workspace, 8, | |||||
| is_first_k, std::min<size_t>(M - m, 4), | |||||
| zp_A, zp_B, zAB); | |||||
| matmul_8x8x4::kern_4x8( | |||||
| packA, cur_packB, K, workspace, 8, is_first_k, | |||||
| std::min<size_t>(M - m, 4), zp_A, zp_B, zAB); | |||||
| #define cb(m, n) \ | #define cb(m, n) \ | ||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \ | ||||
| bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
| @@ -98,15 +97,14 @@ struct KernCaller<bmode, Op, 8, 8, true> { | |||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_8x8x4::kern_4x4(packA, cur_packB, K, workspace, 4, | |||||
| is_first_k, std::min<size_t>(M - m, 4), | |||||
| std::min<size_t>(N - n, 4), zp_A, zp_B, | |||||
| zAB); | |||||
| matmul_8x8x4::kern_4x4( | |||||
| packA, cur_packB, K, workspace, 4, is_first_k, | |||||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4), zp_A, | |||||
| zp_B, zAB); | |||||
| #define cb(m, n) \ | #define cb(m, n) \ | ||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \ | ||||
| bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
| DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||||
| std::min<size_t>(N - n, 4)); | |||||
| DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
| #undef cb | #undef cb | ||||
| output += 4; | output += 4; | ||||
| @@ -124,10 +122,10 @@ struct KernCaller<bmode, Op, 8, 8, true> { | |||||
| template <BiasMode bmode, typename Op> | template <BiasMode bmode, typename Op> | ||||
| struct KernCaller<bmode, Op, 8, 8, false> { | struct KernCaller<bmode, Op, 8, 8, false> { | ||||
| static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, | |||||
| size_t N, size_t K, dt_uint8* C, size_t LDC, | |||||
| bool is_first_k, Op op, const dt_int32* bias, | |||||
| dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||||
| static void run( | |||||
| const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K, | |||||
| dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, | |||||
| dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||||
| megdnn_assert(is_first_k); | megdnn_assert(is_first_k); | ||||
| constexpr size_t A_INTERLEAVE = 8; | constexpr size_t A_INTERLEAVE = 8; | ||||
| @@ -144,27 +142,25 @@ struct KernCaller<bmode, Op, 8, 8, false> { | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const dt_uint8* cur_packB = packB; | const dt_uint8* cur_packB = packB; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_8x8x8::kern_8x8(packA, cur_packB, K, workspace, 8, | |||||
| is_first_k, zp_A, zp_B); | |||||
| matmul_8x8x8::kern_8x8( | |||||
| packA, cur_packB, K, workspace, 8, is_first_k, zp_A, zp_B); | |||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, | |||||
| 8>::postprocess(bias, workspace, | |||||
| output, LDC, op); | |||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, 8>:: | |||||
| postprocess(bias, workspace, output, LDC, op); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K8; | cur_packB += K8; | ||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_8x8x8::kern_8x4(packA, cur_packB, K, workspace, 4, | |||||
| is_first_k, std::min<size_t>(N - n, 4), | |||||
| zp_A, zp_B); | |||||
| matmul_8x8x8::kern_8x4( | |||||
| packA, cur_packB, K, workspace, 4, is_first_k, | |||||
| std::min<size_t>(N - n, 4), zp_A, zp_B); | |||||
| #define cb(m, n) \ | #define cb(m, n) \ | ||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \ | ||||
| bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
| DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4)); | DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4)); | ||||
| #undef cb | #undef cb | ||||
| output += 4; | output += 4; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -179,9 +175,9 @@ struct KernCaller<bmode, Op, 8, 8, false> { | |||||
| const dt_uint8* cur_packB = packB; | const dt_uint8* cur_packB = packB; | ||||
| size_t n = 0; | size_t n = 0; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_8x8x8::kern_4x8(packA, cur_packB, K, workspace, 8, | |||||
| is_first_k, std::min<size_t>(M - m, 4), | |||||
| zp_A, zp_B); | |||||
| matmul_8x8x8::kern_4x8( | |||||
| packA, cur_packB, K, workspace, 8, is_first_k, | |||||
| std::min<size_t>(M - m, 4), zp_A, zp_B); | |||||
| #define cb(m, n) \ | #define cb(m, n) \ | ||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \ | ||||
| bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
| @@ -193,17 +189,16 @@ struct KernCaller<bmode, Op, 8, 8, false> { | |||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_8x8x8::kern_4x4(packA, cur_packB, K, workspace, 4, | |||||
| is_first_k, std::min<size_t>(M - m, 4), | |||||
| std::min<size_t>(N - n, 4), zp_A, zp_B); | |||||
| matmul_8x8x8::kern_4x4( | |||||
| packA, cur_packB, K, workspace, 4, is_first_k, | |||||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4), zp_A, | |||||
| zp_B); | |||||
| #define cb(m, n) \ | #define cb(m, n) \ | ||||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \ | arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \ | ||||
| bias, workspace, output, LDC, op); | bias, workspace, output, LDC, op); | ||||
| DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||||
| std::min<size_t>(N - n, 4)); | |||||
| DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
| #undef cb | #undef cb | ||||
| output += 4; | output += 4; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -219,27 +214,27 @@ struct KernCaller<bmode, Op, 8, 8, false> { | |||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot_nobias_identity) | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot_nobias_identity) | ||||
| void gemm_u8_8x8_dot_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr, | |||||
| int ldin, int y0, int ymax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| void gemm_u8_8x8_dot_nobias_identity::pack_A( | |||||
| uint8_t* outptr, const uint8_t* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0, | |||||
| ymax, k0, kmax); | |||||
| matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper( | |||||
| outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin, | |||||
| y0, ymax, k0, kmax); | |||||
| matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper( | |||||
| outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void gemm_u8_8x8_dot_nobias_identity::pack_B(uint8_t* out, const uint8_t* in, | |||||
| int ldin, int x0, int xmax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| void gemm_u8_8x8_dot_nobias_identity::pack_B( | |||||
| uint8_t* out, const uint8_t* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0, | |||||
| xmax, k0, kmax); | |||||
| matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper( | |||||
| out, in, ldin, x0, xmax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0, xmax, | |||||
| k0, kmax); | |||||
| matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper( | |||||
| out, in, ldin, x0, xmax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| @@ -249,30 +244,27 @@ size_t gemm_u8_8x8_dot_nobias_identity::get_workspace_size() const { | |||||
| #endif | #endif | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nodot_nobias_identity) | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nodot_nobias_identity) | ||||
| void gemm_u8_8x8_nodot_nobias_identity::pack_A(dt_uint8* outptr, | |||||
| const dt_uint8* inptr, int ldin, | |||||
| int y0, int ymax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| void gemm_u8_8x8_nodot_nobias_identity::pack_A( | |||||
| dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point; | uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point; | ||||
| if (transpose) { | if (transpose) { | ||||
| matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, | |||||
| ymax, k0, kmax, zA); | |||||
| matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n( | |||||
| outptr, inptr, ldin, y0, ymax, k0, kmax, zA); | |||||
| } else { | } else { | ||||
| matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
| kmax, zA); | |||||
| matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax, zA); | |||||
| } | } | ||||
| } | } | ||||
| void gemm_u8_8x8_nodot_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in, | |||||
| int ldin, int x0, int xmax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| void gemm_u8_8x8_nodot_nobias_identity::pack_B( | |||||
| dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point; | uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point; | ||||
| if (transpose) { | if (transpose) { | ||||
| matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, | |||||
| k0, kmax, zB); | |||||
| matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n( | |||||
| out, in, ldin, x0, xmax, k0, kmax, zB); | |||||
| } else { | } else { | ||||
| matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, | |||||
| zB); | |||||
| matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, zB); | |||||
| } | } | ||||
| } | } | ||||
| @@ -280,22 +272,21 @@ size_t gemm_u8_8x8_nodot_nobias_identity::get_workspace_size() const { | |||||
| return 8 * 8 * sizeof(dt_int32); | return 8 * 8 * sizeof(dt_int32); | ||||
| } | } | ||||
| #define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline, \ | |||||
| _OP) \ | |||||
| void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline:: \ | |||||
| kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, \ | |||||
| size_t N, size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \ | |||||
| const dt_int32* bias, dt_int32* workspace) const { \ | |||||
| float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
| uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
| float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
| uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
| float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
| uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
| DEFINE_OP(_OP); \ | |||||
| impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \ | |||||
| packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ | |||||
| workspace, zp_A, zp_B); \ | |||||
| #define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline, _OP) \ | |||||
| void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline::kern( \ | |||||
| const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \ | |||||
| size_t K, dt_uint8* C, size_t LDC, bool is_first_k, const dt_int32* bias, \ | |||||
| dt_int32* workspace) const { \ | |||||
| float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
| uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
| float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
| uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
| float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||||
| uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||||
| DEFINE_OP(_OP); \ | |||||
| impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \ | |||||
| packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace, zp_A, \ | |||||
| zp_B); \ | |||||
| } | } | ||||
| #define DEFINE_OP(_Op) \ | #define DEFINE_OP(_Op) \ | ||||
| @@ -311,17 +302,22 @@ KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, relu, ReluOp) | |||||
| KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | ||||
| #undef DEFINE_OP | #undef DEFINE_OP | ||||
| #define DEFINE_OP(_Op) \ | |||||
| arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, \ | |||||
| scale_A* scale_B, scale_C, zp_C); | |||||
| #define DEFINE_OP(_Op) \ | |||||
| arm_common::_Op<dt_qint32, dt_quint8> op( \ | |||||
| scale_A* scale_B, scale_A* scale_B, scale_C, zp_C); | |||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | ||||
| KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||||
| KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||||
| KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, | |||||
| FuseAddReluOp) | |||||
| KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||||
| FuseAddHSwishOp) | |||||
| #endif | #endif | ||||
| KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | |||||
| KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||||
| KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||||
| KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, | |||||
| AddOp) | |||||
| KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, | |||||
| FuseAddReluOp) | |||||
| KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||||
| FuseAddHSwishOp) | |||||
| #undef DEFINE_OP | #undef DEFINE_OP | ||||
| #undef KERN | #undef KERN | ||||
| @@ -16,46 +16,44 @@ namespace aarch64 { | |||||
| namespace matmul { | namespace matmul { | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 4, | |||||
| false, true, | |||||
| gemm_u8_8x8_dot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( | |||||
| dt_uint8, dt_uint8, dt_int32, 8, 8, 4, false, true, | |||||
| gemm_u8_8x8_dot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_relu, | |||||
| gemm_u8_8x8_dot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_u8_8x8_dot_nobias_relu, gemm_u8_8x8_dot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_hswish, | |||||
| gemm_u8_8x8_dot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_u8_8x8_dot_nobias_hswish, gemm_u8_8x8_dot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_identity, | |||||
| gemm_u8_8x8_dot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_u8_8x8_dot_bias_channel_identity, gemm_u8_8x8_dot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_relu, | |||||
| gemm_u8_8x8_dot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_hswish, | |||||
| gemm_u8_8x8_dot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_u8_8x8_dot_bias_channel_relu, gemm_u8_8x8_dot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_u8_8x8_dot_bias_channel_hswish, gemm_u8_8x8_dot_nobias_identity); | |||||
| #endif | #endif | ||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 8, | |||||
| false, true, | |||||
| gemm_u8_8x8_nodot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_relu, | |||||
| gemm_u8_8x8_nodot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( | |||||
| dt_uint8, dt_uint8, dt_int32, 8, 8, 8, false, true, | |||||
| gemm_u8_8x8_nodot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_hswish, | |||||
| gemm_u8_8x8_nodot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_u8_8x8_nodot_nobias_relu, gemm_u8_8x8_nodot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_identity, | |||||
| gemm_u8_8x8_nodot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_u8_8x8_nodot_nobias_hswish, gemm_u8_8x8_nodot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_relu, | |||||
| gemm_u8_8x8_nodot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_u8_8x8_nodot_bias_channel_identity, gemm_u8_8x8_nodot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_hswish, | |||||
| gemm_u8_8x8_nodot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_u8_8x8_nodot_bias_channel_relu, gemm_u8_8x8_nodot_nobias_identity); | |||||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||||
| gemm_u8_8x8_nodot_bias_channel_hswish, gemm_u8_8x8_nodot_nobias_identity); | |||||
| } // namespace matmul | } // namespace matmul | ||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| @@ -11,11 +11,11 @@ | |||||
| #include "src/common/handle_impl.h" | #include "src/common/handle_impl.h" | ||||
| #include "src/aarch64/conv_bias/opr_impl.h" | |||||
| #include "src/aarch64/handle.h" | #include "src/aarch64/handle.h" | ||||
| #include "src/aarch64/matrix_mul/opr_impl.h" | #include "src/aarch64/matrix_mul/opr_impl.h" | ||||
| #include "src/aarch64/rotate/opr_impl.h" | |||||
| #include "src/aarch64/relayout/opr_impl.h" | #include "src/aarch64/relayout/opr_impl.h" | ||||
| #include "src/aarch64/conv_bias/opr_impl.h" | |||||
| #include "src/aarch64/rotate/opr_impl.h" | |||||
| #include "src/aarch64/warp_perspective/opr_impl.h" | #include "src/aarch64/warp_perspective/opr_impl.h" | ||||
| namespace megdnn { | namespace megdnn { | ||||
| @@ -38,7 +38,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpPerspective) | |||||
| MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) | MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) | ||||
| #pragma GCC diagnostic pop | #pragma GCC diagnostic pop | ||||
| } // namespace aarch64 | |||||
| } // namespace megdnn | |||||
| } // namespace aarch64 | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -14,20 +14,18 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace aarch64 { | namespace aarch64 { | ||||
| class HandleImpl: public arm_common::HandleImpl { | |||||
| public: | |||||
| HandleImpl(megcoreComputingHandle_t computing_handle, | |||||
| HandleType type = HandleType::AARCH64): | |||||
| arm_common::HandleImpl::HandleImpl(computing_handle, type) | |||||
| {} | |||||
| class HandleImpl : public arm_common::HandleImpl { | |||||
| public: | |||||
| HandleImpl( | |||||
| megcoreComputingHandle_t computing_handle, | |||||
| HandleType type = HandleType::AARCH64) | |||||
| : arm_common::HandleImpl::HandleImpl(computing_handle, type) {} | |||||
| template <typename Opr> | |||||
| std::unique_ptr<Opr> create_operator(); | |||||
| template <typename Opr> | |||||
| std::unique_ptr<Opr> create_operator(); | |||||
| }; | }; | ||||
| } // namespace aarch64 | |||||
| } // namespace megdnn | |||||
| } // namespace aarch64 | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -21,9 +21,7 @@ namespace aarch64 { | |||||
| class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "AARCH64_F32K8X12X1"; } | const char* name() const override { return "AARCH64_F32K8X12X1"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| @@ -35,8 +33,7 @@ public: | |||||
| class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } | const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -48,9 +45,7 @@ public: | |||||
| class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "AARCH64_F32K4X16X1"; } | const char* name() const override { return "AARCH64_F32K4X16X1"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| @@ -61,9 +56,7 @@ public: | |||||
| class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "AARCH64_F32_MK4_4x16"; } | const char* name() const override { return "AARCH64_F32_MK4_4x16"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| @@ -73,8 +66,7 @@ public: | |||||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_4x16) | MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_4x16) | ||||
| }; | }; | ||||
| class MatrixMulImpl::AlgoF32Gemv final | |||||
| : public arm_common::MatrixMulImpl::AlgoF32Gemv { | |||||
| class MatrixMulImpl::AlgoF32Gemv final : public arm_common::MatrixMulImpl::AlgoF32Gemv { | |||||
| public: | public: | ||||
| AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { | AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { | ||||
| m_handle_type = Handle::HandleType::AARCH64; | m_handle_type = Handle::HandleType::AARCH64; | ||||
| @@ -85,9 +77,7 @@ public: | |||||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
| class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { | class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "AARCH64_F16_K8X24X1"; } | const char* name() const override { return "AARCH64_F16_K8X24X1"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| @@ -98,9 +88,7 @@ public: | |||||
| class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "AARCH64_F16_MK8_8X8"; } | const char* name() const override { return "AARCH64_F16_MK8_8X8"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| @@ -115,12 +103,8 @@ public: | |||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| const char* name() const override { | |||||
| return "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; } | |||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| @@ -130,12 +114,8 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| const char* name() const override { | |||||
| return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"; } | |||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| @@ -147,8 +127,7 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } | const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -163,9 +142,7 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; } | const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
| @@ -178,9 +155,7 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; } | const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
| @@ -192,9 +167,7 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; } | const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
| @@ -207,9 +180,7 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; } | const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
| @@ -222,8 +193,7 @@ public: | |||||
| class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } | const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -238,12 +208,9 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | |||||
| const char* name() const override { | |||||
| return "AARCH64_INT8X8X16_MK4_16X12X4"; | |||||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "AARCH64_INT8X8X16_MK4_16X12X4"; } | |||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| @@ -257,12 +224,9 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | |||||
| const char* name() const override { | |||||
| return "AARCH64_INT8X8X16_MK4_K8X8X8"; | |||||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "AARCH64_INT8X8X16_MK4_K8X8X8"; } | |||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| @@ -276,8 +240,7 @@ public: | |||||
| class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } | const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -292,9 +255,7 @@ public: | |||||
| class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { | class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; } | const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| bool preferred(const KernSizeParam&) const override; | bool preferred(const KernSizeParam&) const override; | ||||
| @@ -306,9 +267,7 @@ public: | |||||
| class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; } | const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| @@ -321,12 +280,8 @@ public: | |||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| const char* name() const override { | |||||
| return "AARCH64_QUINT8_K8X8X4_DOTPROD"; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "AARCH64_QUINT8_K8X8X4_DOTPROD"; } | |||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
| @@ -336,8 +291,7 @@ public: | |||||
| class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
| return AlgoAttribute::REPRODUCIBLE | | |||||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
| } | } | ||||
| const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } | const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| @@ -352,9 +306,7 @@ public: | |||||
| #endif | #endif | ||||
| class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase { | class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase { | ||||
| public: | public: | ||||
| AlgoAttribute attribute() const override { | |||||
| return AlgoAttribute::REPRODUCIBLE; | |||||
| } | |||||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||||
| const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; } | const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; } | ||||
| bool usable(const KernSizeParam&) const override; | bool usable(const KernSizeParam&) const override; | ||||
| size_t get_workspace(const KernSizeParam&) const override; | size_t get_workspace(const KernSizeParam&) const override; | ||||
| @@ -16,11 +16,11 @@ namespace megdnn { | |||||
| namespace aarch64 { | namespace aarch64 { | ||||
| namespace matmul { | namespace matmul { | ||||
| MEGDNN_REG_GEMM_STRATEGY(dt_float16, dt_float16, dt_float16, 8, 24, 1, false, | |||||
| true, hgemm_8x24); | |||||
| MEGDNN_REG_GEMM_STRATEGY( | |||||
| dt_float16, dt_float16, dt_float16, 8, 24, 1, false, true, hgemm_8x24); | |||||
| MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_float16, dt_float16, dt_float16, 8, 8, 1, | |||||
| false, true, gemm_nopack_f16_8x8); | |||||
| MEGDNN_REG_GEMM_STRATEGY_NOPACK( | |||||
| dt_float16, dt_float16, dt_float16, 8, 8, 1, false, true, gemm_nopack_f16_8x8); | |||||
| } // namespace matmul | } // namespace matmul | ||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| @@ -9,8 +9,8 @@ | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #include "src/aarch64/matrix_mul/fp16/strategy.h" | |||||
| #include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
| #include "src/aarch64/matrix_mul/fp16/strategy.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -21,8 +21,9 @@ using namespace aarch64::matmul; | |||||
| namespace { | namespace { | ||||
| void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
| dt_float16* output) { | |||||
| void kern_8x1( | |||||
| const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
| dt_float16* output) { | |||||
| LDB *= sizeof(dt_float16); | LDB *= sizeof(dt_float16); | ||||
| asm volatile( | asm volatile( | ||||
| ".arch armv8.2-a+fp16\n" | ".arch armv8.2-a+fp16\n" | ||||
| @@ -86,9 +87,8 @@ void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
| : | : | ||||
| : "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", | |||||
| "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", | |||||
| "memory"); | |||||
| : "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", | |||||
| "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); | |||||
| } | } | ||||
| // Overview of register layout: | // Overview of register layout: | ||||
| @@ -115,8 +115,9 @@ void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
| // |v23[0-7]| |v27[0-7]| | // |v23[0-7]| |v27[0-7]| | ||||
| // +--------+ +--------+ | // +--------+ +--------+ | ||||
| // Accumulator | // Accumulator | ||||
| void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
| dt_float16* output) { | |||||
| void kern_8x4( | |||||
| const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
| dt_float16* output) { | |||||
| //! LDB is the number of elements in one block of B. We read 24 elements | //! LDB is the number of elements in one block of B. We read 24 elements | ||||
| //! first, so subtract 24 elements (24 * 2 bytes) here. | //! first, so subtract 24 elements (24 * 2 bytes) here. | ||||
| LDB = (LDB - 24) * sizeof(dt_float16); | LDB = (LDB - 24) * sizeof(dt_float16); | ||||
| @@ -263,8 +264,8 @@ void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
| "v22", "v23", "v24", "v25", "v26", "v27", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||||
| "v23", "v24", "v25", "v26", "v27", "cc", "memory"); | |||||
| } | } | ||||
| // Overview of register layout: | // Overview of register layout: | ||||
| @@ -295,8 +296,9 @@ void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
| // | v7[0-7]| |v31[0-7]| | // | v7[0-7]| |v31[0-7]| | ||||
| // +--------+ +--------+ | // +--------+ +--------+ | ||||
| // Accumulator | // Accumulator | ||||
| void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
| dt_float16* output) { | |||||
| void kern_8x8( | |||||
| const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
| dt_float16* output) { | |||||
| //! As each load 128 number from B, but the pos add 112 * 2, so we minus 112 | //! As each load 128 number from B, but the pos add 112 * 2, so we minus 112 | ||||
| //! here. | //! here. | ||||
| LDB = (LDB - 32) * sizeof(dt_float16); | LDB = (LDB - 32) * sizeof(dt_float16); | ||||
| @@ -467,20 +469,19 @@ void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27", | |||||
| "v28", "v29", "v30", "v31", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27", "v28", "v29", | |||||
| "v30", "v31", "cc", "memory"); | |||||
| } | } | ||||
| } // anonymous namespace | } // anonymous namespace | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_8x8); | MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_8x8); | ||||
| void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA, | |||||
| const dt_float16* B, size_t LDB, dt_float16* C, | |||||
| size_t LDC, size_t M, size_t K, size_t N, | |||||
| const dt_float16*, void*, bool trA, | |||||
| bool trB) const { | |||||
| void gemm_nopack_f16_8x8::kern( | |||||
| const dt_float16* A, size_t LDA, const dt_float16* B, size_t LDB, dt_float16* C, | |||||
| size_t LDC, size_t M, size_t K, size_t N, const dt_float16*, void*, bool trA, | |||||
| bool trB) const { | |||||
| constexpr static size_t MB = 8; | constexpr static size_t MB = 8; | ||||
| constexpr static size_t KB = 8; | constexpr static size_t KB = 8; | ||||
| constexpr static size_t NB = 8; | constexpr static size_t NB = 8; | ||||
| @@ -17,21 +17,23 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace aarch64 { | namespace aarch64 { | ||||
| MEGDNN_NOINLINE void sgemm_packA_n(const float* A, float* Apacked, size_t M, | |||||
| size_t K, size_t LDA, const float* alpha); | |||||
| MEGDNN_NOINLINE void sgemm_packA_n( | |||||
| const float* A, float* Apacked, size_t M, size_t K, size_t LDA, | |||||
| const float* alpha); | |||||
| MEGDNN_NOINLINE void sgemm_packA_t(const float* A, float* Apacked, size_t M, | |||||
| size_t K, size_t LDA, const float* alpha); | |||||
| MEGDNN_NOINLINE void sgemm_packA_t( | |||||
| const float* A, float* Apacked, size_t M, size_t K, size_t LDA, | |||||
| const float* alpha); | |||||
| MEGDNN_NOINLINE void sgemm_packB_n(const float* B, float* Bpacked, size_t K, | |||||
| size_t N, size_t LDB); | |||||
| MEGDNN_NOINLINE void sgemm_packB_n( | |||||
| const float* B, float* Bpacked, size_t K, size_t N, size_t LDB); | |||||
| MEGDNN_NOINLINE void sgemm_packB_t(const float* B, float* Bpacked, size_t K, | |||||
| size_t N, size_t LDB); | |||||
| MEGDNN_NOINLINE void sgemm_packB_t( | |||||
| const float* B, float* Bpacked, size_t K, size_t N, size_t LDB); | |||||
| MEGDNN_NOINLINE void sgemm_kernel12x8(const float* A, const float* B, float* C, | |||||
| size_t LDC, size_t M, size_t N, size_t K, | |||||
| int type, const float* beta); | |||||
| MEGDNN_NOINLINE void sgemm_kernel12x8( | |||||
| const float* A, const float* B, float* C, size_t LDC, size_t M, size_t N, | |||||
| size_t K, int type, const float* beta); | |||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -12,7 +12,6 @@ | |||||
| #include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace aarch64 { | namespace aarch64 { | ||||
| namespace matmul_general_4x16 { | namespace matmul_general_4x16 { | ||||
| @@ -39,8 +38,9 @@ namespace matmul_general_4x16 { | |||||
| // +--+ - - - - +--------+--------+--------+--------+ | // +--+ - - - - +--------+--------+--------+--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| void kern_4x16(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k, int m_remain) { | |||||
| void kern_4x16( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int m_remain) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| int oddk = (K & 1); | int oddk = (K & 1); | ||||
| @@ -224,14 +224,14 @@ void kern_4x16(const float* packA, const float* packB, int K, | |||||
| "6:\n" STORE_C | "6:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
| [m_remain] "+r"(m_remain), [outptr] "+r"(outptr) | [m_remain] "+r"(m_remain), [outptr] "+r"(outptr) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
| "v20", "v21", "v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9", | |||||
| "x10", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
| "v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9", "x10", "cc", | |||||
| "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -263,8 +263,9 @@ void kern_4x16(const float* packA, const float* packB, int K, | |||||
| // +--+--+ - - - - +--------+ | // +--+--+ - - - - +--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| void kern_4x4(const float* packA, const float* packB, int K, float* output, | |||||
| int LDC, bool is_first_k, int m_remain, int n_remain) { | |||||
| void kern_4x4( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int m_remain, int n_remain) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| int oddk = (K & 1); | int oddk = (K & 1); | ||||
| @@ -330,99 +331,100 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output, | |||||
| STORE_LINE("6", "2") \ | STORE_LINE("6", "2") \ | ||||
| STORE_LINE("7", "3") \ | STORE_LINE("7", "3") \ | ||||
| "105:\n" | "105:\n" | ||||
| // clang-format on | |||||
| asm volatile( | |||||
| // load accumulator C | |||||
| "add x1, x0, %x[LDC]\n" | |||||
| "add x2, x1, %x[LDC]\n" | |||||
| "add x3, x2, %x[LDC]\n" | |||||
| "cmp %w[is_first_k], #1\n" | |||||
| "beq 1f\n" LOAD_C | |||||
| "b 2f\n" | |||||
| "1:\n" | |||||
| "eor v4.16b, v4.16b, v4.16b\n" | |||||
| "eor v5.16b, v5.16b, v5.16b\n" | |||||
| "eor v6.16b, v6.16b, v6.16b\n" | |||||
| "eor v7.16b, v7.16b, v7.16b\n" | |||||
| "2: \n" | |||||
| "ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||||
| "ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||||
| "cmp %w[K], #0\n" | |||||
| "beq 4f\n" | |||||
| "3:\n" | |||||
| "ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||||
| "ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||||
| "fmla v4.4s, v2.4s, v0.s[0]\n" | |||||
| "fmla v5.4s, v2.4s, v0.s[1]\n" | |||||
| "fmla v6.4s, v2.4s, v0.s[2]\n" | |||||
| "fmla v7.4s, v2.4s, v0.s[3]\n" | |||||
| "ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||||
| "ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||||
| "fmla v4.4s, v3.4s, v1.s[0]\n" | |||||
| "fmla v5.4s, v3.4s, v1.s[1]\n" | |||||
| "fmla v6.4s, v3.4s, v1.s[2]\n" | |||||
| "fmla v7.4s, v3.4s, v1.s[3]\n" | |||||
| "subs %w[K], %w[K], #1\n" | |||||
| "bne 3b\n" | |||||
| "4:\n" | |||||
| "cmp %w[oddk], #1\n" | |||||
| "beq 5f\n" | |||||
| // Even tail | |||||
| "ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||||
| "ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||||
| "fmla v4.4s, v2.4s, v0.s[0]\n" | |||||
| "fmla v5.4s, v2.4s, v0.s[1]\n" | |||||
| "fmla v6.4s, v2.4s, v0.s[2]\n" | |||||
| "fmla v7.4s, v2.4s, v0.s[3]\n" | |||||
| "fmla v4.4s, v3.4s, v1.s[0]\n" | |||||
| "fmla v5.4s, v3.4s, v1.s[1]\n" | |||||
| "fmla v6.4s, v3.4s, v1.s[2]\n" | |||||
| "fmla v7.4s, v3.4s, v1.s[3]\n" | |||||
| "b 6f\n" | |||||
| // odd tail | |||||
| "5:\n" | |||||
| "fmla v4.4s, v2.4s, v0.s[0]\n" | |||||
| "fmla v5.4s, v2.4s, v0.s[1]\n" | |||||
| "fmla v6.4s, v2.4s, v0.s[2]\n" | |||||
| "fmla v7.4s, v2.4s, v0.s[3]\n" | |||||
| "6:\n" STORE_C | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [oddk] "+r"(oddk), [m_remain] "+r"(m_remain), | |||||
| [n_remain] "+r"(n_remain), [outptr] "+r"(outptr) | |||||
| : | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1", | |||||
| "x2", "x3", "x10", "cc", "memory"); | |||||
| // clang-format on | |||||
| asm volatile( | |||||
| // load accumulator C | |||||
| "add x1, x0, %x[LDC]\n" | |||||
| "add x2, x1, %x[LDC]\n" | |||||
| "add x3, x2, %x[LDC]\n" | |||||
| "cmp %w[is_first_k], #1\n" | |||||
| "beq 1f\n" LOAD_C | |||||
| "b 2f\n" | |||||
| "1:\n" | |||||
| "eor v4.16b, v4.16b, v4.16b\n" | |||||
| "eor v5.16b, v5.16b, v5.16b\n" | |||||
| "eor v6.16b, v6.16b, v6.16b\n" | |||||
| "eor v7.16b, v7.16b, v7.16b\n" | |||||
| "2: \n" | |||||
| "ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||||
| "ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||||
| "cmp %w[K], #0\n" | |||||
| "beq 4f\n" | |||||
| "3:\n" | |||||
| "ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||||
| "ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||||
| "fmla v4.4s, v2.4s, v0.s[0]\n" | |||||
| "fmla v5.4s, v2.4s, v0.s[1]\n" | |||||
| "fmla v6.4s, v2.4s, v0.s[2]\n" | |||||
| "fmla v7.4s, v2.4s, v0.s[3]\n" | |||||
| "ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||||
| "ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||||
| "fmla v4.4s, v3.4s, v1.s[0]\n" | |||||
| "fmla v5.4s, v3.4s, v1.s[1]\n" | |||||
| "fmla v6.4s, v3.4s, v1.s[2]\n" | |||||
| "fmla v7.4s, v3.4s, v1.s[3]\n" | |||||
| "subs %w[K], %w[K], #1\n" | |||||
| "bne 3b\n" | |||||
| "4:\n" | |||||
| "cmp %w[oddk], #1\n" | |||||
| "beq 5f\n" | |||||
| // Even tail | |||||
| "ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||||
| "ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||||
| "fmla v4.4s, v2.4s, v0.s[0]\n" | |||||
| "fmla v5.4s, v2.4s, v0.s[1]\n" | |||||
| "fmla v6.4s, v2.4s, v0.s[2]\n" | |||||
| "fmla v7.4s, v2.4s, v0.s[3]\n" | |||||
| "fmla v4.4s, v3.4s, v1.s[0]\n" | |||||
| "fmla v5.4s, v3.4s, v1.s[1]\n" | |||||
| "fmla v6.4s, v3.4s, v1.s[2]\n" | |||||
| "fmla v7.4s, v3.4s, v1.s[3]\n" | |||||
| "b 6f\n" | |||||
| // odd tail | |||||
| "5:\n" | |||||
| "fmla v4.4s, v2.4s, v0.s[0]\n" | |||||
| "fmla v5.4s, v2.4s, v0.s[1]\n" | |||||
| "fmla v6.4s, v2.4s, v0.s[2]\n" | |||||
| "fmla v7.4s, v2.4s, v0.s[3]\n" | |||||
| "6:\n" STORE_C | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
| [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain), | |||||
| [outptr] "+r"(outptr) | |||||
| : | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1", "x2", "x3", "x10", | |||||
| "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| #undef STORE_LINE | #undef STORE_LINE | ||||
| #undef STORE_C | #undef STORE_C | ||||
| } | } | ||||
| void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||||
| int ymax, int k0, int kmax) { | |||||
| void sgemm_4x16_pack_A_n( | |||||
| float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| float zerobuff[4]; | float zerobuff[4]; | ||||
| std::memset(zerobuff, 0, sizeof(float) * 4); | std::memset(zerobuff, 0, sizeof(float) * 4); | ||||
| constexpr int PACK_SIZE = 4*4; | |||||
| constexpr int PACK_SIZE = 4 * 4; | |||||
| int y = y0; | int y = y0; | ||||
| for (; y + 3 < ymax; y += 4) { | for (; y + 3 < ymax; y += 4) { | ||||
| // printf("main loop pack_a_n %p \n",outptr); | |||||
| // printf("main loop pack_a_n %p \n",outptr); | |||||
| const float* inptr0 = inptr + y * ldin + k0; | const float* inptr0 = inptr + y * ldin + k0; | ||||
| const float* inptr1 = inptr0 + ldin; | const float* inptr1 = inptr0 + ldin; | ||||
| const float* inptr2 = inptr1 + ldin; | const float* inptr2 = inptr1 + ldin; | ||||
| @@ -459,9 +461,11 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||||
| switch ((y + 3) - ymax) { | switch ((y + 3) - ymax) { | ||||
| /* Everything falls through in here */ | /* Everything falls through in here */ | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -478,9 +482,11 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -493,8 +499,8 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||||
| } | } | ||||
| } | } | ||||
| void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, | |||||
| int xmax, int k0, int kmax) { | |||||
| void sgemm_4x16_pack_A_t( | |||||
| float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| int ksize = kmax - k0; | int ksize = kmax - k0; | ||||
| int ksize4 = (ksize << 2); | int ksize4 = (ksize << 2); | ||||
| float* outptr_base = out; | float* outptr_base = out; | ||||
| @@ -515,8 +521,7 @@ void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, | |||||
| auto outptr = outptr_base; | auto outptr = outptr_base; | ||||
| for (; x + 4 <= xmax; x += 4) { | for (; x + 4 <= xmax; x += 4) { | ||||
| auto outptr_interleave = outptr; | auto outptr_interleave = outptr; | ||||
| interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||||
| outptr_interleave); | |||||
| interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||||
| outptr += ksize4; | outptr += ksize4; | ||||
| } | } | ||||
| @@ -546,8 +551,8 @@ void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, | |||||
| } | } | ||||
| } | } | ||||
| void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, | |||||
| int x0, int xmax, int k0, int kmax) { | |||||
| void sgemm_4x16_pack_B_n( | |||||
| float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| int ksize = kmax - k0; | int ksize = kmax - k0; | ||||
| int ksize16 = ksize * 16; | int ksize16 = ksize * 16; | ||||
| int ksize4 = (ksize << 2); | int ksize4 = (ksize << 2); | ||||
| @@ -570,15 +575,13 @@ void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, | |||||
| auto outptr = outptr_base; | auto outptr = outptr_base; | ||||
| for (; x + 16 <= xmax; x += 16) { | for (; x + 16 <= xmax; x += 16) { | ||||
| auto outptr_interleave = outptr; | auto outptr_interleave = outptr; | ||||
| interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3, | |||||
| outptr_interleave); | |||||
| interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||||
| outptr += ksize16; | outptr += ksize16; | ||||
| } | } | ||||
| outptr = outptr_base4; | outptr = outptr_base4; | ||||
| for (; x + 4 <= xmax; x += 4) { | for (; x + 4 <= xmax; x += 4) { | ||||
| auto outptr_interleave = outptr; | auto outptr_interleave = outptr; | ||||
| interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||||
| outptr_interleave); | |||||
| interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||||
| outptr += ksize4; | outptr += ksize4; | ||||
| } | } | ||||
| @@ -616,8 +619,8 @@ void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, | |||||
| } | } | ||||
| } | } | ||||
| void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||||
| int y0, int ymax, int k0, int kmax) { | |||||
| void sgemm_4x16_pack_B_t( | |||||
| float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) { | |||||
| float* outptr = out; | float* outptr = out; | ||||
| const float* inptr = in; | const float* inptr = in; | ||||
| float zerobuff[4]; | float zerobuff[4]; | ||||
| @@ -642,8 +645,7 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||||
| int x = (kmax - k0); | int x = (kmax - k0); | ||||
| for (; x > 3; x -= 4) { | for (; x > 3; x -= 4) { | ||||
| transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, | |||||
| 64); | |||||
| transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, 64); | |||||
| outptr_inner += 64; | outptr_inner += 64; | ||||
| } | } | ||||
| for (; x > 0; x--) { | for (; x > 0; x--) { | ||||
| @@ -676,9 +678,11 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||||
| switch ((y + 3) - ymax) { | switch ((y + 3) - ymax) { | ||||
| /* Everything falls through in here */ | /* Everything falls through in here */ | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -696,9 +700,11 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||||
| switch ((y + 3) - ymax) { | switch ((y + 3) - ymax) { | ||||
| /* Everything falls through in here */ | /* Everything falls through in here */ | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -711,8 +717,8 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||||
| } | } | ||||
| } | } | ||||
| } // matmul_general_4x16 | |||||
| } // aarch64 | |||||
| } // megdnn | |||||
| } // namespace matmul_general_4x16 | |||||
| } // namespace aarch64 | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -43,8 +43,9 @@ struct matmul_general_8x12 { | |||||
| // +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_8x12(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k) { | |||||
| static void kern_8x12( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| int oddk = (K & 1); | int oddk = (K & 1); | ||||
| @@ -306,14 +307,13 @@ struct matmul_general_8x12 { | |||||
| "6:\n" | "6:\n" | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr) | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
| [outptr] "+r"(outptr) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
| "v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", | |||||
| "x6", "x7", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||||
| "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
| "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -348,9 +348,9 @@ struct matmul_general_8x12 { | |||||
| // +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_8x4(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k, | |||||
| int n_remain) { | |||||
| static void kern_8x4( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int n_remain) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| int oddk = (K & 1); | int oddk = (K & 1); | ||||
| @@ -520,13 +520,12 @@ struct matmul_general_8x12 { | |||||
| "6:\n" STORE_C | "6:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
| [n_remain] "+r"(n_remain) | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
| [outptr] "+r"(outptr), [n_remain] "+r"(n_remain) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", | |||||
| "v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||||
| "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", | |||||
| "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", | |||||
| "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -557,9 +556,9 @@ struct matmul_general_8x12 { | |||||
| // +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_4x12(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k, | |||||
| int m_remain) { | |||||
| static void kern_4x12( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int m_remain) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| int oddk = (K & 1); | int oddk = (K & 1); | ||||
| @@ -717,13 +716,12 @@ struct matmul_general_8x12 { | |||||
| "6:\n" STORE_C | "6:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
| [m_remain] "+r"(m_remain) | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
| [outptr] "+r"(outptr), [m_remain] "+r"(m_remain) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
| "v19", "x1", "x2", "x3", "x10", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||||
| "x2", "x3", "x10", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -754,9 +752,9 @@ struct matmul_general_8x12 { | |||||
| // +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_4x4(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k, int m_remain, | |||||
| int n_remain) { | |||||
| static void kern_4x4( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int m_remain, int n_remain) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| int oddk = (K & 1); | int oddk = (K & 1); | ||||
| @@ -895,20 +893,21 @@ struct matmul_general_8x12 { | |||||
| "6:\n" STORE_C | "6:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
| [n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
| [outptr] "+r"(outptr), [n_remain] "+r"(n_remain), | |||||
| [m_remain] "+r"(m_remain) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", | |||||
| "x3", "x10", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3", | |||||
| "x10", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| #undef STORE_LINE | #undef STORE_LINE | ||||
| #undef STORE_C | #undef STORE_C | ||||
| } | } | ||||
| static void sgemm_8x12_pack_A_n(float* outptr, const float* inptr, int ldin, | |||||
| int y0, int ymax, int k0, int kmax) { | |||||
| static void sgemm_8x12_pack_A_n( | |||||
| float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| float zerobuff[8]; | float zerobuff[8]; | ||||
| std::memset(zerobuff, 0, sizeof(float) * 8); | std::memset(zerobuff, 0, sizeof(float) * 8); | ||||
| constexpr int PACK_SIZE_32 = 4 * 8; | constexpr int PACK_SIZE_32 = 4 * 8; | ||||
| @@ -933,8 +932,9 @@ struct matmul_general_8x12 { | |||||
| prefetch_2x(inptr7); | prefetch_2x(inptr7); | ||||
| int x = (kmax - k0); | int x = (kmax - k0); | ||||
| for (; x > 3; x -= 4) { | for (; x > 3; x -= 4) { | ||||
| transpose_8x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, | |||||
| inptr5, inptr6, inptr7, outptr); | |||||
| transpose_8x4_1_s( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr); | |||||
| outptr += PACK_SIZE_32; | outptr += PACK_SIZE_32; | ||||
| } | } | ||||
| for (; x > 0; x--) { | for (; x > 0; x--) { | ||||
| @@ -1004,8 +1004,8 @@ struct matmul_general_8x12 { | |||||
| } | } | ||||
| } | } | ||||
| static void sgemm_8x12_pack_A_t(float* out, const float* in, int ldin, | |||||
| int x0, int xmax, int k0, int kmax) { | |||||
| static void sgemm_8x12_pack_A_t( | |||||
| float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| int ksize = kmax - k0; | int ksize = kmax - k0; | ||||
| int ksize8 = (ksize << 3); | int ksize8 = (ksize << 3); | ||||
| int ksize4 = (ksize << 2); | int ksize4 = (ksize << 2); | ||||
| @@ -1028,20 +1028,17 @@ struct matmul_general_8x12 { | |||||
| auto outptr = outptr_base; | auto outptr = outptr_base; | ||||
| for (; x + 8 <= xmax; x += 8) { | for (; x + 8 <= xmax; x += 8) { | ||||
| auto outptr_interleave = outptr; | auto outptr_interleave = outptr; | ||||
| interleave_4x8_1_s(inptr, inptr1, inptr2, inptr3, | |||||
| outptr_interleave); | |||||
| interleave_4x8_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||||
| outptr += ksize8; | outptr += ksize8; | ||||
| } | } | ||||
| outptr = outptr_base4; | outptr = outptr_base4; | ||||
| for (; x + 4 <= xmax; x += 4) { | for (; x + 4 <= xmax; x += 4) { | ||||
| auto outptr_interleave = outptr; | auto outptr_interleave = outptr; | ||||
| interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||||
| outptr_interleave); | |||||
| interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||||
| outptr += ksize4; | outptr += ksize4; | ||||
| } | } | ||||
| if (x < xmax) { | if (x < xmax) { | ||||
| interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, | |||||
| xmax - x); | |||||
| interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); | |||||
| } | } | ||||
| outptr_base += 4 * 8; | outptr_base += 4 * 8; | ||||
| outptr_base4 += 4 * 4; | outptr_base4 += 4 * 4; | ||||
| @@ -1071,8 +1068,8 @@ struct matmul_general_8x12 { | |||||
| } | } | ||||
| } | } | ||||
| static void sgemm_8x12_pack_B_n(float* out, const float* in, int ldin, | |||||
| int x0, int xmax, int k0, int kmax) { | |||||
| static void sgemm_8x12_pack_B_n( | |||||
| float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| int ksize = kmax - k0; | int ksize = kmax - k0; | ||||
| int ksize12 = ksize * 12; | int ksize12 = ksize * 12; | ||||
| int ksize4 = (ksize << 2); | int ksize4 = (ksize << 2); | ||||
| @@ -1095,20 +1092,17 @@ struct matmul_general_8x12 { | |||||
| auto outptr = outptr_base; | auto outptr = outptr_base; | ||||
| for (; x + 12 <= xmax; x += 12) { | for (; x + 12 <= xmax; x += 12) { | ||||
| auto outptr_interleave = outptr; | auto outptr_interleave = outptr; | ||||
| interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, | |||||
| outptr_interleave); | |||||
| interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||||
| outptr += ksize12; | outptr += ksize12; | ||||
| } | } | ||||
| outptr = outptr_base4; | outptr = outptr_base4; | ||||
| for (; x + 4 <= xmax; x += 4) { | for (; x + 4 <= xmax; x += 4) { | ||||
| auto outptr_interleave = outptr; | auto outptr_interleave = outptr; | ||||
| interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||||
| outptr_interleave); | |||||
| interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||||
| outptr += ksize4; | outptr += ksize4; | ||||
| } | } | ||||
| if (x < xmax) { | if (x < xmax) { | ||||
| interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, | |||||
| xmax - x); | |||||
| interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); | |||||
| } | } | ||||
| outptr_base += 12 * 4; | outptr_base += 12 * 4; | ||||
| outptr_base4 += 4 * 4; | outptr_base4 += 4 * 4; | ||||
| @@ -1138,8 +1132,8 @@ struct matmul_general_8x12 { | |||||
| } | } | ||||
| } | } | ||||
| static void sgemm_8x12_pack_B_t(float* out, const float* in, int ldin, | |||||
| int y0, int ymax, int k0, int kmax) { | |||||
| static void sgemm_8x12_pack_B_t( | |||||
| float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) { | |||||
| float* outptr = out; | float* outptr = out; | ||||
| const float* inptr = in; | const float* inptr = in; | ||||
| float zerobuff[12]; | float zerobuff[12]; | ||||
| @@ -1172,9 +1166,9 @@ struct matmul_general_8x12 { | |||||
| prefetch_2x(inptr11); | prefetch_2x(inptr11); | ||||
| int x = (kmax - k0); | int x = (kmax - k0); | ||||
| for (; x > 3; x -= 4) { | for (; x > 3; x -= 4) { | ||||
| transpose_12x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, | |||||
| inptr5, inptr6, inptr7, inptr8, inptr9, | |||||
| inptr10, inptr11, outptr); | |||||
| transpose_12x4_1_s( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| inptr8, inptr9, inptr10, inptr11, outptr); | |||||
| outptr += 48; | outptr += 48; | ||||
| } | } | ||||
| for (; x > 0; x--) { | for (; x > 0; x--) { | ||||
| @@ -43,8 +43,9 @@ struct matmul_general_8x12_a53 { | |||||
| // +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_8x12(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k) { | |||||
| static void kern_8x12( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| int oddk = (K & 1); | int oddk = (K & 1); | ||||
| @@ -575,15 +576,14 @@ struct matmul_general_8x12_a53 { | |||||
| "6:\n" | "6:\n" | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr) | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
| [outptr] "+r"(outptr) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
| "v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", | |||||
| "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", | |||||
| "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||||
| "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
| "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", | |||||
| "x11", "x12", "x13", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| } | } | ||||
| @@ -615,9 +615,9 @@ struct matmul_general_8x12_a53 { | |||||
| // +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_8x4(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k, | |||||
| int n_remain) { | |||||
| static void kern_8x4( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int n_remain) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| int oddk = (K & 1); | int oddk = (K & 1); | ||||
| @@ -856,13 +856,12 @@ struct matmul_general_8x12_a53 { | |||||
| "6:\n" STORE_C | "6:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
| [n_remain] "+r"(n_remain) | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
| [outptr] "+r"(outptr), [n_remain] "+r"(n_remain) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", | |||||
| "v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||||
| "x8", "x9", "x10", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", | |||||
| "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", | |||||
| "x10", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -893,9 +892,9 @@ struct matmul_general_8x12_a53 { | |||||
| // +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_4x12(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k, | |||||
| int m_remain) { | |||||
| static void kern_4x12( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int m_remain) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| int oddk = (K & 1); | int oddk = (K & 1); | ||||
| @@ -1133,14 +1132,12 @@ struct matmul_general_8x12_a53 { | |||||
| "6:\n" STORE_C | "6:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
| [m_remain] "+r"(m_remain) | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
| [outptr] "+r"(outptr), [m_remain] "+r"(m_remain) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
| "v19", "x1", "x2", "x3", "x8", "x9", "x10", "x20", "x21", | |||||
| "x22", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||||
| "x2", "x3", "x8", "x9", "x10", "x20", "x21", "x22", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -1171,9 +1168,9 @@ struct matmul_general_8x12_a53 { | |||||
| // +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_4x4(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k, int m_remain, | |||||
| int n_remain) { | |||||
| static void kern_4x4( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int m_remain, int n_remain) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| int oddk = (K & 1); | int oddk = (K & 1); | ||||
| @@ -1312,12 +1309,12 @@ struct matmul_general_8x12_a53 { | |||||
| "6:\n" STORE_C | "6:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
| [n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
| [outptr] "+r"(outptr), [n_remain] "+r"(n_remain), | |||||
| [m_remain] "+r"(m_remain) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", | |||||
| "x3", "x10", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3", | |||||
| "x10", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| #undef STORE_LINE | #undef STORE_LINE | ||||
| @@ -43,8 +43,9 @@ struct matmul_general_8x12_a55 { | |||||
| // +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_8x12(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k) { | |||||
| static void kern_8x12( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| int oddk = (K & 1); | int oddk = (K & 1); | ||||
| @@ -525,15 +526,14 @@ struct matmul_general_8x12_a55 { | |||||
| "6:\n" | "6:\n" | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr) | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
| [outptr] "+r"(outptr) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
| "v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", | |||||
| "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", | |||||
| "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||||
| "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
| "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", | |||||
| "x11", "x12", "x13", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| } | } | ||||
| @@ -565,9 +565,9 @@ struct matmul_general_8x12_a55 { | |||||
| // +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_8x4(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k, | |||||
| int n_remain) { | |||||
| static void kern_8x4( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int n_remain) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| int oddk = (K & 1); | int oddk = (K & 1); | ||||
| @@ -742,13 +742,12 @@ struct matmul_general_8x12_a55 { | |||||
| "6:\n" STORE_C | "6:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
| [n_remain] "+r"(n_remain) | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
| [outptr] "+r"(outptr), [n_remain] "+r"(n_remain) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", | |||||
| "v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||||
| "x10", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", | |||||
| "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x10", "cc", | |||||
| "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -779,9 +778,9 @@ struct matmul_general_8x12_a55 { | |||||
| // +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_4x12(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k, | |||||
| int m_remain) { | |||||
| static void kern_4x12( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int m_remain) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| int oddk = (K & 1); | int oddk = (K & 1); | ||||
| @@ -972,14 +971,12 @@ struct matmul_general_8x12_a55 { | |||||
| "6:\n" STORE_C | "6:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
| [m_remain] "+r"(m_remain) | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
| [outptr] "+r"(outptr), [m_remain] "+r"(m_remain) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
| "v19", "x1", "x2", "x3", "x10", "x20", "x21", "x22", "cc", | |||||
| "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||||
| "x2", "x3", "x10", "x20", "x21", "x22", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -1010,9 +1007,9 @@ struct matmul_general_8x12_a55 { | |||||
| // +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_4x4(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k, int m_remain, | |||||
| int n_remain) { | |||||
| static void kern_4x4( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int m_remain, int n_remain) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| int oddk = (K & 1); | int oddk = (K & 1); | ||||
| @@ -1151,12 +1148,12 @@ struct matmul_general_8x12_a55 { | |||||
| "6:\n" STORE_C | "6:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||||
| [n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||||
| [outptr] "+r"(outptr), [n_remain] "+r"(n_remain), | |||||
| [m_remain] "+r"(m_remain) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", | |||||
| "x3", "x10", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3", | |||||
| "x10", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| #undef STORE_LINE | #undef STORE_LINE | ||||
| @@ -44,8 +44,9 @@ struct matmul_mk4_8x12 { | |||||
| // +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_8x12(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k) { | |||||
| static void kern_8x12( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| float* output0 = output; | float* output0 = output; | ||||
| @@ -307,10 +308,10 @@ struct matmul_mk4_8x12 { | |||||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
| [output0] "+r"(output0), [output1] "+r"(output1) | [output0] "+r"(output0), [output1] "+r"(output1) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
| "v28", "v29", "v30", "v31", "x1", "x2", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||||
| "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
| "v31", "x1", "x2", "cc", "memory"); | |||||
| } | } | ||||
| // Overview of register layout: | // Overview of register layout: | ||||
| @@ -340,9 +341,9 @@ struct matmul_mk4_8x12 { | |||||
| // +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_8x4(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k, | |||||
| int n_remain) { | |||||
| static void kern_8x4( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int n_remain) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| float* output0 = output; | float* output0 = output; | ||||
| @@ -500,8 +501,8 @@ struct matmul_mk4_8x12 { | |||||
| [output0] "+r"(output0), [output1] "+r"(output1), | [output0] "+r"(output0), [output1] "+r"(output1), | ||||
| [n_remain] "+r"(n_remain) | [n_remain] "+r"(n_remain) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", | |||||
| "v13", "v14", "v15", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
| "v15", "cc", "memory"); | |||||
| #undef LOAD_C | #undef LOAD_C | ||||
| #undef STORE_C | #undef STORE_C | ||||
| @@ -531,8 +532,9 @@ struct matmul_mk4_8x12 { | |||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_4x12(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k) { | |||||
| static void kern_4x12( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k) { | |||||
| MEGDNN_MARK_USED_VAR(LDC); | MEGDNN_MARK_USED_VAR(LDC); | ||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| @@ -669,9 +671,9 @@ struct matmul_mk4_8x12 { | |||||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
| [output0] "+r"(output0) | [output0] "+r"(output0) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
| "v19", "x1", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||||
| "cc", "memory"); | |||||
| } | } | ||||
| // Overview of register layout: | // Overview of register layout: | ||||
| @@ -697,9 +699,9 @@ struct matmul_mk4_8x12 { | |||||
| // +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_4x4(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k, | |||||
| int n_remain) { | |||||
| static void kern_4x4( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int n_remain) { | |||||
| MEGDNN_MARK_USED_VAR(LDC); | MEGDNN_MARK_USED_VAR(LDC); | ||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| @@ -818,15 +820,15 @@ struct matmul_mk4_8x12 { | |||||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
| [output0] "+r"(output0), [n_remain] "+r"(n_remain) | [output0] "+r"(output0), [n_remain] "+r"(n_remain) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", | |||||
| "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); | |||||
| #undef LOAD_C | #undef LOAD_C | ||||
| #undef STORE_C | #undef STORE_C | ||||
| } | } | ||||
| static void sgemm_8x12_pack_A(float* outptr, const float* inptr, int ldin, | |||||
| int y0, int ymax, int k0, int kmax) { | |||||
| static void sgemm_8x12_pack_A( | |||||
| float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4"); | megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4"); | ||||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | ||||
| constexpr int PACK_SIZE_32 = 4 * 8; | constexpr int PACK_SIZE_32 = 4 * 8; | ||||
| @@ -855,8 +857,8 @@ struct matmul_mk4_8x12 { | |||||
| } | } | ||||
| } | } | ||||
| static void sgemm_8x12_pack_B(float* out, const float* in, int ldin, int x0, | |||||
| int xmax, int k0, int kmax) { | |||||
| static void sgemm_8x12_pack_B( | |||||
| float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | ||||
| float tmpbuff[16] = {0.0f}; | float tmpbuff[16] = {0.0f}; | ||||
| @@ -886,8 +888,7 @@ struct matmul_mk4_8x12 { | |||||
| outptr += ksize4; | outptr += ksize4; | ||||
| } | } | ||||
| if (x < xmax) { | if (x < xmax) { | ||||
| std::memcpy(tmpbuff, inptr, | |||||
| sizeof(float) * (xmax - x) * PACK_C_SIZE); | |||||
| std::memcpy(tmpbuff, inptr, sizeof(float) * (xmax - x) * PACK_C_SIZE); | |||||
| auto outptr_interleave = outptr; | auto outptr_interleave = outptr; | ||||
| const float* tmp_ptr = &tmpbuff[0]; | const float* tmp_ptr = &tmpbuff[0]; | ||||
| transpose_1x4_4_s<float>(tmp_ptr, outptr_interleave); | transpose_1x4_4_s<float>(tmp_ptr, outptr_interleave); | ||||
| @@ -44,8 +44,9 @@ struct matmul_mk4_8x12_a53 { | |||||
| // +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_8x12(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k) { | |||||
| static void kern_8x12( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| float* output0 = output; | float* output0 = output; | ||||
| @@ -553,11 +554,11 @@ struct matmul_mk4_8x12_a53 { | |||||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
| [output0] "+r"(output0), [output1] "+r"(output1) | [output0] "+r"(output0), [output1] "+r"(output1) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
| "v28", "v29", "v30", "v31", "x1", "x2", "x8", "x9", "x10", | |||||
| "x11", "x12", "x13", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||||
| "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
| "v31", "x1", "x2", "x8", "x9", "x10", "x11", "x12", "x13", "cc", | |||||
| "memory"); | |||||
| } | } | ||||
| // Overview of register layout: | // Overview of register layout: | ||||
| @@ -587,9 +588,9 @@ struct matmul_mk4_8x12_a53 { | |||||
| // +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_8x4(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k, | |||||
| int n_remain) { | |||||
| static void kern_8x4( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int n_remain) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| float* output0 = output; | float* output0 = output; | ||||
| @@ -831,8 +832,8 @@ struct matmul_mk4_8x12_a53 { | |||||
| [output0] "+r"(output0), [output1] "+r"(output1), | [output0] "+r"(output0), [output1] "+r"(output1), | ||||
| [n_remain] "+r"(n_remain) | [n_remain] "+r"(n_remain) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", | |||||
| "v13", "v14", "v15", "x8", "x9", "x10", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
| "v15", "x8", "x9", "x10", "cc", "memory"); | |||||
| #undef LOAD_C | #undef LOAD_C | ||||
| #undef STORE_C | #undef STORE_C | ||||
| @@ -862,8 +863,9 @@ struct matmul_mk4_8x12_a53 { | |||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_4x12(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k) { | |||||
| static void kern_4x12( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k) { | |||||
| MEGDNN_MARK_USED_VAR(LDC); | MEGDNN_MARK_USED_VAR(LDC); | ||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| @@ -1098,9 +1100,9 @@ struct matmul_mk4_8x12_a53 { | |||||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
| [output0] "+r"(output0) | [output0] "+r"(output0) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
| "v19", "x1", "x8", "x9", "x10", "x11", "x12", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||||
| "x8", "x9", "x10", "x11", "x12", "cc", "memory"); | |||||
| } | } | ||||
| // Overview of register layout: | // Overview of register layout: | ||||
| @@ -1126,9 +1128,9 @@ struct matmul_mk4_8x12_a53 { | |||||
| // +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_4x4(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k, | |||||
| int n_remain) { | |||||
| static void kern_4x4( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int n_remain) { | |||||
| MEGDNN_MARK_USED_VAR(LDC); | MEGDNN_MARK_USED_VAR(LDC); | ||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| @@ -1246,8 +1248,7 @@ struct matmul_mk4_8x12_a53 { | |||||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
| [output0] "+r"(output0), [n_remain] "+r"(n_remain) | [output0] "+r"(output0), [n_remain] "+r"(n_remain) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", | |||||
| "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); | |||||
| #undef LOAD_C | #undef LOAD_C | ||||
| #undef STORE_C | #undef STORE_C | ||||
| @@ -44,8 +44,9 @@ struct matmul_mk4_8x12_a55 { | |||||
| // +--+ --- - +--------+--------+--------+ | // +--+ --- - +--------+--------+--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_8x12(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k) { | |||||
| static void kern_8x12( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| float* output0 = output; | float* output0 = output; | ||||
| @@ -519,11 +520,11 @@ struct matmul_mk4_8x12_a55 { | |||||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
| [output0] "+r"(output0), [output1] "+r"(output1) | [output0] "+r"(output0), [output1] "+r"(output1) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||||
| "v28", "v29", "v30", "v31", "x1", "x2", "x8", "x9", "x10", | |||||
| "x11", "x12", "x13", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||||
| "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
| "v31", "x1", "x2", "x8", "x9", "x10", "x11", "x12", "x13", "cc", | |||||
| "memory"); | |||||
| } | } | ||||
| // Overview of register layout: | // Overview of register layout: | ||||
| @@ -553,9 +554,9 @@ struct matmul_mk4_8x12_a55 { | |||||
| // +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_8x4(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k, | |||||
| int n_remain) { | |||||
| static void kern_8x4( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int n_remain) { | |||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| float* output0 = output; | float* output0 = output; | ||||
| @@ -749,8 +750,8 @@ struct matmul_mk4_8x12_a55 { | |||||
| [output0] "+r"(output0), [output1] "+r"(output1), | [output0] "+r"(output0), [output1] "+r"(output1), | ||||
| [n_remain] "+r"(n_remain) | [n_remain] "+r"(n_remain) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", | |||||
| "v13", "v14", "v15", "x8", "x9", "x10", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
| "v15", "x8", "x9", "x10", "cc", "memory"); | |||||
| #undef LOAD_C | #undef LOAD_C | ||||
| #undef STORE_C | #undef STORE_C | ||||
| @@ -780,8 +781,9 @@ struct matmul_mk4_8x12_a55 { | |||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_4x12(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k) { | |||||
| static void kern_4x12( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k) { | |||||
| MEGDNN_MARK_USED_VAR(LDC); | MEGDNN_MARK_USED_VAR(LDC); | ||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| @@ -997,9 +999,9 @@ struct matmul_mk4_8x12_a55 { | |||||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
| [output0] "+r"(output0) | [output0] "+r"(output0) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
| "v19", "x1", "x8", "x9", "x10", "x11", "x12", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||||
| "x8", "x9", "x10", "x11", "x12", "cc", "memory"); | |||||
| } | } | ||||
| // Overview of register layout: | // Overview of register layout: | ||||
| @@ -1025,9 +1027,9 @@ struct matmul_mk4_8x12_a55 { | |||||
| // +--+ --- - +--------+ | // +--+ --- - +--------+ | ||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| static void kern_4x4(const float* packA, const float* packB, int K, | |||||
| float* output, int LDC, bool is_first_k, | |||||
| int n_remain) { | |||||
| static void kern_4x4( | |||||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||||
| bool is_first_k, int n_remain) { | |||||
| MEGDNN_MARK_USED_VAR(LDC); | MEGDNN_MARK_USED_VAR(LDC); | ||||
| const float* a_ptr = packA; | const float* a_ptr = packA; | ||||
| const float* b_ptr = packB; | const float* b_ptr = packB; | ||||
| @@ -1146,8 +1148,7 @@ struct matmul_mk4_8x12_a55 { | |||||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | ||||
| [output0] "+r"(output0), [n_remain] "+r"(n_remain) | [output0] "+r"(output0), [n_remain] "+r"(n_remain) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", | |||||
| "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); | |||||
| #undef LOAD_C | #undef LOAD_C | ||||
| #undef STORE_C | #undef STORE_C | ||||
| @@ -10,6 +10,7 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "src/aarch64/matrix_mul/fp32/strategy.h" | |||||
| #include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" | #include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" | ||||
| #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" | #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" | ||||
| #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h" | #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h" | ||||
| @@ -17,44 +18,40 @@ | |||||
| #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h" | #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h" | ||||
| #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h" | #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h" | ||||
| #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h" | #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h" | ||||
| #include "src/aarch64/matrix_mul/fp32/strategy.h" | |||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace aarch64; | using namespace aarch64; | ||||
| using namespace aarch64::matmul; | using namespace aarch64::matmul; | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16); | MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16); | ||||
| void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, int ymax, | |||||
| int k0, int kmax, bool transpose_A) const { | |||||
| void sgemm_4x16::pack_A( | |||||
| float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
| bool transpose_A) const { | |||||
| if (transpose_A) { | if (transpose_A) { | ||||
| matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | |||||
| int k0, int kmax, bool transpose_B) const { | |||||
| void sgemm_4x16::pack_B( | |||||
| float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool transpose_B) const { | |||||
| if (transpose_B) { | if (transpose_B) { | ||||
| matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, | |||||
| kmax); | |||||
| matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, | |||||
| kmax); | |||||
| matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, | |||||
| size_t N, size_t K, float* C, size_t LDC, bool is_first_k, | |||||
| const float*, float*) const { | |||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
| A_dtype.enumv() == C_dtype.enumv() && | |||||
| A_dtype.enumv() == DTypeEnum::Float32); | |||||
| void sgemm_4x16::kern( | |||||
| const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||||
| size_t LDC, bool is_first_k, const float*, float*) const { | |||||
| megdnn_assert( | |||||
| A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && | |||||
| A_dtype.enumv() == DTypeEnum::Float32); | |||||
| MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
| MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
| MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
| @@ -71,9 +68,9 @@ void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const float* cur_packB = packB; | const float* cur_packB = packB; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, | |||||
| std::min<size_t>(M - m, 4)); | |||||
| matmul_general_4x16::kern_4x16( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4)); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K16; | cur_packB += K16; | ||||
| } | } | ||||
| @@ -92,32 +89,30 @@ void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, | |||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12); | MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12); | ||||
| void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, int ymax, | |||||
| int k0, int kmax, bool transpose_A) const { | |||||
| void sgemm_8x12::pack_A( | |||||
| float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
| bool transpose_A) const { | |||||
| if (transpose_A) { | if (transpose_A) { | ||||
| matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | |||||
| int k0, int kmax, bool transpose_B) const { | |||||
| void sgemm_8x12::pack_B( | |||||
| float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool transpose_B) const { | |||||
| if (transpose_B) { | if (transpose_B) { | ||||
| matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, | |||||
| kmax); | |||||
| matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, | |||||
| kmax); | |||||
| matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| template <typename gemm_class> | template <typename gemm_class> | ||||
| static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||||
| size_t M, size_t N, size_t K, float* C, | |||||
| size_t LDC, bool is_first_k) { | |||||
| static inline void sgemm_8x12_helper( | |||||
| const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||||
| size_t LDC, bool is_first_k) { | |||||
| constexpr size_t A_INTERLEAVE = 8; | constexpr size_t A_INTERLEAVE = 8; | ||||
| constexpr size_t A_INTERLEAVE4 = 4; | constexpr size_t A_INTERLEAVE4 = 4; | ||||
| constexpr size_t B_INTERLEAVE = 12; | constexpr size_t B_INTERLEAVE = 12; | ||||
| @@ -138,8 +133,9 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| gemm_class::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| gemm_class::kern_8x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| output += 4; | output += 4; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -150,16 +146,17 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const float* cur_packB = packB; | const float* cur_packB = packB; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| gemm_class::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4)); | |||||
| gemm_class::kern_4x12( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4)); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K12; | cur_packB += K12; | ||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| gemm_class::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4), | |||||
| std::min<size_t>(N - n, 4)); | |||||
| gemm_class::kern_4x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
| output += 4; | output += 4; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -167,56 +164,55 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||||
| } | } | ||||
| } | } | ||||
| void sgemm_8x12::kern(const float* packA, const float* packB, size_t M, | |||||
| size_t N, size_t K, float* C, size_t LDC, bool is_first_k, | |||||
| const float*, float*) const { | |||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
| A_dtype.enumv() == C_dtype.enumv() && | |||||
| A_dtype.enumv() == DTypeEnum::Float32); | |||||
| void sgemm_8x12::kern( | |||||
| const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||||
| size_t LDC, bool is_first_k, const float*, float*) const { | |||||
| megdnn_assert( | |||||
| A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && | |||||
| A_dtype.enumv() == DTypeEnum::Float32); | |||||
| MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
| MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
| MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
| #if !MGB_ENABLE_CPUINFO | #if !MGB_ENABLE_CPUINFO | ||||
| sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, | |||||
| is_first_k); | |||||
| sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, is_first_k); | |||||
| #else | #else | ||||
| auto arch = cpuinfo_get_current_core()->uarch; | auto arch = cpuinfo_get_current_core()->uarch; | ||||
| #ifdef __IN_TEE_ENV__ | #ifdef __IN_TEE_ENV__ | ||||
| arch = cpuinfo_uarch_unknown; | arch = cpuinfo_uarch_unknown; | ||||
| #endif | #endif | ||||
| if (arch == cpuinfo_uarch_cortex_a53) { | if (arch == cpuinfo_uarch_cortex_a53) { | ||||
| sgemm_8x12_helper<matmul_general_8x12_a53>(packA, packB, M, N, K, C, | |||||
| LDC, is_first_k); | |||||
| sgemm_8x12_helper<matmul_general_8x12_a53>( | |||||
| packA, packB, M, N, K, C, LDC, is_first_k); | |||||
| } else if (arch == cpuinfo_uarch_cortex_a55) { | } else if (arch == cpuinfo_uarch_cortex_a55) { | ||||
| sgemm_8x12_helper<matmul_general_8x12_a55>(packA, packB, M, N, K, C, | |||||
| LDC, is_first_k); | |||||
| sgemm_8x12_helper<matmul_general_8x12_a55>( | |||||
| packA, packB, M, N, K, C, LDC, is_first_k); | |||||
| } else { | } else { | ||||
| sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, | |||||
| is_first_k); | |||||
| sgemm_8x12_helper<matmul_general_8x12>( | |||||
| packA, packB, M, N, K, C, LDC, is_first_k); | |||||
| } | } | ||||
| #endif | #endif | ||||
| } | } | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12); | MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12); | ||||
| void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0, | |||||
| int ymax, int k0, int kmax, | |||||
| bool transpose_A) const { | |||||
| void sgemm_mk4_8x12::pack_A( | |||||
| float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
| bool transpose_A) const { | |||||
| megdnn_assert(!transpose_A, "mk4 float matmul not support transpose A"); | megdnn_assert(!transpose_A, "mk4 float matmul not support transpose A"); | ||||
| matmul_mk4_8x12::sgemm_8x12_pack_A(out, in, ldin, y0, ymax, k0, kmax); | matmul_mk4_8x12::sgemm_8x12_pack_A(out, in, ldin, y0, ymax, k0, kmax); | ||||
| } | } | ||||
| void sgemm_mk4_8x12::pack_B(float* out, const float* in, int ldin, int x0, | |||||
| int xmax, int k0, int kmax, | |||||
| bool transpose_B) const { | |||||
| void sgemm_mk4_8x12::pack_B( | |||||
| float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool transpose_B) const { | |||||
| megdnn_assert(!transpose_B, "mk4 float matmul not support transpose B"); | megdnn_assert(!transpose_B, "mk4 float matmul not support transpose B"); | ||||
| matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); | matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); | ||||
| } | } | ||||
| template <typename gemm_name> | template <typename gemm_name> | ||||
| static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, | |||||
| size_t M, size_t N, size_t K, float* C, | |||||
| size_t LDC, bool is_first_k) { | |||||
| static inline void sgemm_mk4_8x12_helper( | |||||
| const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||||
| size_t LDC, bool is_first_k) { | |||||
| const int K12 = K * 12; | const int K12 = K * 12; | ||||
| const int K8 = K * 8; | const int K8 = K * 8; | ||||
| const int K4 = K * 4; | const int K4 = K * 4; | ||||
| @@ -237,8 +233,9 @@ static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, | |||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| gemm_name::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| gemm_name::kern_8x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| output += 4 * PACK_C_SIZE; | output += 4 * PACK_C_SIZE; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -254,41 +251,41 @@ static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, | |||||
| cur_packB += K12; | cur_packB += K12; | ||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| gemm_name::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| gemm_name::kern_4x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| output += 4 * PACK_C_SIZE; | output += 4 * PACK_C_SIZE; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| packA += K4; | packA += K4; | ||||
| } | } | ||||
| } | } | ||||
| void sgemm_mk4_8x12::kern(const float* packA, const float* packB, size_t M, | |||||
| size_t N, size_t K, float* C, size_t LDC, | |||||
| bool is_first_k, const float*, float*) const { | |||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
| A_dtype.enumv() == C_dtype.enumv() && | |||||
| A_dtype.enumv() == DTypeEnum::Float32); | |||||
| void sgemm_mk4_8x12::kern( | |||||
| const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||||
| size_t LDC, bool is_first_k, const float*, float*) const { | |||||
| megdnn_assert( | |||||
| A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && | |||||
| A_dtype.enumv() == DTypeEnum::Float32); | |||||
| MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
| MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
| MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
| megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); | megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); | ||||
| #if !MGB_ENABLE_CPUINFO | #if !MGB_ENABLE_CPUINFO | ||||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, | |||||
| is_first_k); | |||||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, is_first_k); | |||||
| #else | #else | ||||
| auto arch = cpuinfo_get_current_core()->uarch; | auto arch = cpuinfo_get_current_core()->uarch; | ||||
| #ifdef __IN_TEE_ENV__ | #ifdef __IN_TEE_ENV__ | ||||
| arch = cpuinfo_uarch_unknown; | arch = cpuinfo_uarch_unknown; | ||||
| #endif | #endif | ||||
| if (arch == cpuinfo_uarch_cortex_a53) { | if (arch == cpuinfo_uarch_cortex_a53) { | ||||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>(packA, packB, M, N, K, C, | |||||
| LDC, is_first_k); | |||||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>( | |||||
| packA, packB, M, N, K, C, LDC, is_first_k); | |||||
| } else if (arch == cpuinfo_uarch_cortex_a55) { | } else if (arch == cpuinfo_uarch_cortex_a55) { | ||||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>(packA, packB, M, N, K, C, | |||||
| LDC, is_first_k); | |||||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>( | |||||
| packA, packB, M, N, K, C, LDC, is_first_k); | |||||
| } else { | } else { | ||||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, | |||||
| is_first_k); | |||||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12>( | |||||
| packA, packB, M, N, K, C, LDC, is_first_k); | |||||
| } | } | ||||
| #endif | #endif | ||||
| } | } | ||||
| @@ -15,17 +15,14 @@ | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace aarch64 { | namespace aarch64 { | ||||
| namespace matmul { | namespace matmul { | ||||
| MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, | |||||
| sgemm_8x12); | |||||
| MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12); | |||||
| MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true, | |||||
| sgemm_4x16); | |||||
| MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true, sgemm_4x16); | |||||
| MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, false, | |||||
| sgemm_mk4_8x12); | |||||
| MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, false, sgemm_mk4_8x12); | |||||
| MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 16, 1, false, true, | |||||
| sgemm_nopack_4x16); | |||||
| MEGDNN_REG_GEMM_STRATEGY_NOPACK( | |||||
| float, float, float, 4, 16, 1, false, true, sgemm_nopack_4x16); | |||||
| } // namespace matmul | } // namespace matmul | ||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| @@ -20,8 +20,8 @@ using namespace aarch64::matmul; | |||||
| namespace { | namespace { | ||||
| void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
| float* output) { | |||||
| void kern_4x1( | |||||
| const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) { | |||||
| LDB *= sizeof(float); | LDB *= sizeof(float); | ||||
| asm volatile( | asm volatile( | ||||
| "subs %w[K], %w[K], #4\n" | "subs %w[K], %w[K], #4\n" | ||||
| @@ -64,8 +64,7 @@ void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
| : | : | ||||
| : "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc", | |||||
| "memory"); | |||||
| : "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc", "memory"); | |||||
| } | } | ||||
| // Overview of register layout: | // Overview of register layout: | ||||
| @@ -89,8 +88,8 @@ void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
| // +--------+ - - - - -+--------+ | // +--------+ - - - - -+--------+ | ||||
| // Accumulator | // Accumulator | ||||
| void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
| float* output) { | |||||
| void kern_4x4( | |||||
| const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) { | |||||
| //! As each load 16 number from B, but the pos add 12 * 4, so we minus 12 | //! As each load 16 number from B, but the pos add 12 * 4, so we minus 12 | ||||
| //! here. | //! here. | ||||
| LDB = (LDB - 12) * sizeof(float); | LDB = (LDB - 12) * sizeof(float); | ||||
| @@ -165,8 +164,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", | |||||
| "v18", "v19", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", | |||||
| "v19", "cc", "memory"); | |||||
| } | } | ||||
| // Overview of register layout: | // Overview of register layout: | ||||
| @@ -195,8 +194,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
| // +--------+ - - - - -+--------+ | // +--------+ - - - - -+--------+ | ||||
| // Accumulator | // Accumulator | ||||
| void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
| float* output) { | |||||
| void kern_4x8( | |||||
| const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) { | |||||
| //! As each load 32 number from B, but the pos add 24 * 4, so we minus 24 | //! As each load 32 number from B, but the pos add 24 * 4, so we minus 24 | ||||
| //! here. | //! here. | ||||
| LDB = (LDB - 24) * sizeof(float); | LDB = (LDB - 24) * sizeof(float); | ||||
| @@ -304,9 +303,9 @@ void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", | |||||
| "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", | |||||
| "v27", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", | |||||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "cc", | |||||
| "memory"); | |||||
| } | } | ||||
| // Overview of register layout: | // Overview of register layout: | ||||
| @@ -342,8 +341,7 @@ void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||||
| // +--------+ | // +--------+ | ||||
| // Accumulator | // Accumulator | ||||
| void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, | |||||
| float* output) { | |||||
| void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, float* output) { | |||||
| //! As each load 64 number from B, but the pos add 56 * 4, so we minus 56 | //! As each load 64 number from B, but the pos add 56 * 4, so we minus 56 | ||||
| //! here. | //! here. | ||||
| LDB = (LDB - 56) * sizeof(float); | LDB = (LDB - 56) * sizeof(float); | ||||
| @@ -565,20 +563,18 @@ void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", | |||||
| "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", | |||||
| "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |||||
| "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(sgemm_nopack_4x16); | MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(sgemm_nopack_4x16); | ||||
| void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B, | |||||
| size_t LDB, float* C, size_t LDC, size_t M, | |||||
| size_t K, size_t N, const float*, void*, bool trA, | |||||
| bool trB) const { | |||||
| void sgemm_nopack_4x16::kern( | |||||
| const float* A, size_t LDA, const float* B, size_t LDB, float* C, size_t LDC, | |||||
| size_t M, size_t K, size_t N, const float*, void*, bool trA, bool trB) const { | |||||
| constexpr static size_t MB = 4; | constexpr static size_t MB = 4; | ||||
| constexpr static size_t KB = 4; | constexpr static size_t KB = 4; | ||||
| constexpr static size_t NB = 16; | constexpr static size_t NB = 16; | ||||
| @@ -46,8 +46,9 @@ namespace matmul_12x8x1 { | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_12x8(const int16_t* packA, const int16_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k) { | |||||
| static void kern_12x8( | |||||
| const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k) { | |||||
| const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
| const int16_t* b_ptr = packB; | const int16_t* b_ptr = packB; | ||||
| @@ -155,15 +156,13 @@ static void kern_12x8(const int16_t* packA, const int16_t* packB, int K, | |||||
| "stp q25, q26, [x9]\n" | "stp q25, q26, [x9]\n" | ||||
| "stp q27, q28, [x10]\n" | "stp q27, q28, [x10]\n" | ||||
| "stp q29, q30, [x11]\n" | "stp q29, q30, [x11]\n" | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [output] "+r"(output) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", | |||||
| "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||||
| "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "x1", | |||||
| "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", | |||||
| "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", | |||||
| "v25", "v26", "v27", "v28", "v29", "v30", "x1", "x2", "x3", "x4", "x5", | |||||
| "x6", "x7", "x8", "x9", "x10", "x11", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| #undef STORE_LINE | #undef STORE_LINE | ||||
| @@ -196,8 +195,9 @@ static void kern_12x8(const int16_t* packA, const int16_t* packB, int K, | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_8x8(const int16_t* packA, const int16_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k) { | |||||
| static void kern_8x8( | |||||
| const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k) { | |||||
| const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
| const int16_t* b_ptr = packB; | const int16_t* b_ptr = packB; | ||||
| @@ -276,13 +276,12 @@ static void kern_8x8(const int16_t* packA, const int16_t* packB, int K, | |||||
| "stp q17, q18, [x5]\n" | "stp q17, q18, [x5]\n" | ||||
| "stp q19, q20, [x6]\n" | "stp q19, q20, [x6]\n" | ||||
| "stp q21, q22, [x7]\n" | "stp q21, q22, [x7]\n" | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [output] "+r"(output) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output) | |||||
| : | : | ||||
| : "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "x1", | |||||
| "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); | |||||
| : "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |||||
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "x1", "x2", "x3", "x4", | |||||
| "x5", "x6", "x7", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| #undef STORE_LINE | #undef STORE_LINE | ||||
| @@ -311,9 +310,9 @@ static void kern_8x8(const int16_t* packA, const int16_t* packB, int K, | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_4x8(const int16_t* packA, const int16_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k, | |||||
| size_t m_remain) { | |||||
| static void kern_4x8( | |||||
| const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k, size_t m_remain) { | |||||
| const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
| const int16_t* b_ptr = packB; | const int16_t* b_ptr = packB; | ||||
| @@ -388,14 +387,13 @@ static void kern_4x8(const int16_t* packA, const int16_t* packB, int K, | |||||
| "cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
| "3:\n" STORE_C | "3:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), | |||||
| [m_remain] "+r"(m_remain) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [x0] "+r"(x0), [m_remain] "+r"(m_remain) | |||||
| : | : | ||||
| : "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
| "cc", "memory"); | |||||
| : "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "cc", | |||||
| "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| #undef STORE_LINE | #undef STORE_LINE | ||||
| @@ -432,9 +430,9 @@ static void kern_4x8(const int16_t* packA, const int16_t* packB, int K, | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k, | |||||
| size_t n_remain) { | |||||
| static void kern_12x4( | |||||
| const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k, size_t n_remain) { | |||||
| const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
| const int16_t* b_ptr = packB; | const int16_t* b_ptr = packB; | ||||
| @@ -573,18 +571,16 @@ static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, | |||||
| "cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
| "3:\n" STORE_C | "3:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||||
| [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), | |||||
| [outptr8] "=r"(outptr8), [outptr9] "=r"(outptr9), | |||||
| [outptr10] "=r"(outptr10), [outptr11] "=r"(outptr11), | |||||
| [x0] "+r"(x0), [n_remain] "+r"(n_remain) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||||
| [outptr7] "=r"(outptr7), [outptr8] "=r"(outptr8), [outptr9] "=r"(outptr9), | |||||
| [outptr10] "=r"(outptr10), [outptr11] "=r"(outptr11), [x0] "+r"(x0), | |||||
| [n_remain] "+r"(n_remain) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
| "v15", "v16", "v17", "v18", "v19", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |||||
| "v16", "v17", "v18", "v19", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -618,9 +614,9 @@ static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_8x4(const int16_t* packA, const int16_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k, | |||||
| size_t n_remain) { | |||||
| static void kern_8x4( | |||||
| const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k, size_t n_remain) { | |||||
| const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
| const int16_t* b_ptr = packB; | const int16_t* b_ptr = packB; | ||||
| @@ -734,16 +730,14 @@ static void kern_8x4(const int16_t* packA, const int16_t* packB, int K, | |||||
| "cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
| "3:\n" STORE_C | "3:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||||
| [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), | |||||
| [n_remain] "+r"(n_remain) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||||
| [outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) | |||||
| : | : | ||||
| : "v0", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |||||
| "cc", "memory"); | |||||
| : "v0", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "cc", | |||||
| "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -773,9 +767,9 @@ static void kern_8x4(const int16_t* packA, const int16_t* packB, int K, | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k, size_t m_remain, | |||||
| size_t n_remain) { | |||||
| static void kern_4x4( | |||||
| const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k, size_t m_remain, size_t n_remain) { | |||||
| const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
| const int16_t* b_ptr = packB; | const int16_t* b_ptr = packB; | ||||
| @@ -874,11 +868,10 @@ static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, | |||||
| "cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
| "3:\n" STORE_C | "3:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), | |||||
| [m_remain] "+r"(m_remain), [x1] "+r"(x1), | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [x0] "+r"(x0), [m_remain] "+r"(m_remain), [x1] "+r"(x1), | |||||
| [n_remain] "+r"(n_remain) | [n_remain] "+r"(n_remain) | ||||
| : | : | ||||
| : "v0", "v2", "v8", "v9", "v10", "v11", "cc", "memory"); | : "v0", "v2", "v8", "v9", "v10", "v11", "cc", "memory"); | ||||
| @@ -889,9 +882,9 @@ static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, | |||||
| #undef STORE_C | #undef STORE_C | ||||
| } | } | ||||
| static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||||
| int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| static void gemm_s16_12x8x1_pack_A_n( | |||||
| int16_t* outptr, const int16_t* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| int16_t zerobuff[4]; | int16_t zerobuff[4]; | ||||
| std::memset(zerobuff, 0, sizeof(int16_t) * 4); | std::memset(zerobuff, 0, sizeof(int16_t) * 4); | ||||
| @@ -925,15 +918,15 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||||
| int K = kmax - k0; | int K = kmax - k0; | ||||
| for (; K > 3; K -= 4) { | for (; K > 3; K -= 4) { | ||||
| interleave_12x1_4_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, inptr8, inptr9, inptr10, | |||||
| inptr11, outptr); | |||||
| interleave_12x1_4_h( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| inptr8, inptr9, inptr10, inptr11, outptr); | |||||
| } | } | ||||
| if (K > 0) { | if (K > 0) { | ||||
| interleave_12(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, | |||||
| outptr, 1, K); | |||||
| interleave_12( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| inptr8, inptr9, inptr10, inptr11, outptr, 1, K); | |||||
| } | } | ||||
| } | } | ||||
| @@ -949,13 +942,15 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||||
| int K = kmax - k0; | int K = kmax - k0; | ||||
| for (; K > 7; K -= 8) { | for (; K > 7; K -= 8) { | ||||
| interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, outptr); | |||||
| interleave_8x1_8_h( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr); | |||||
| } | } | ||||
| if (K > 0) { | if (K > 0) { | ||||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
| inptr7, outptr, 1, K); | |||||
| interleave_8( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr, 1, K); | |||||
| } | } | ||||
| } | } | ||||
| @@ -975,9 +970,11 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -992,9 +989,11 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1007,9 +1006,8 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_s16_12x8x1_transpose_pack_A_n(int16_t* out, const int16_t* in, | |||||
| int ldin, int x0, int xmax, | |||||
| int k0, int kmax) { | |||||
| static void gemm_s16_12x8x1_transpose_pack_A_n( | |||||
| int16_t* out, const int16_t* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
| const int ksize4 = ksize * 4; | const int ksize4 = ksize * 4; | ||||
| const int ksize8 = ksize4 * 2; | const int ksize8 = ksize4 * 2; | ||||
| @@ -1054,8 +1052,8 @@ static void gemm_s16_12x8x1_transpose_pack_A_n(int16_t* out, const int16_t* in, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_s16_12x8x1_pack_B_n(int16_t* out, const int16_t* in, int ldin, | |||||
| int x0, int xmax, int k0, int kmax) { | |||||
| static void gemm_s16_12x8x1_pack_B_n( | |||||
| int16_t* out, const int16_t* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
| const int ksize4 = ksize * 4; | const int ksize4 = ksize * 4; | ||||
| const int ksize8 = ksize4 * 2; | const int ksize8 = ksize4 * 2; | ||||
| @@ -1090,10 +1088,9 @@ static void gemm_s16_12x8x1_pack_B_n(int16_t* out, const int16_t* in, int ldin, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, | |||||
| const int16_t* inptr, int ldin, | |||||
| int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| static void gemm_s16_12x8x1_transpose_pack_B_n( | |||||
| int16_t* outptr, const int16_t* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| int16_t zerobuff[4]; | int16_t zerobuff[4]; | ||||
| std::memset(zerobuff, 0, sizeof(int16_t) * 4); | std::memset(zerobuff, 0, sizeof(int16_t) * 4); | ||||
| @@ -1110,13 +1107,15 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, | |||||
| int K = kmax - k0; | int K = kmax - k0; | ||||
| for (; K > 7; K -= 8) { | for (; K > 7; K -= 8) { | ||||
| interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, outptr); | |||||
| interleave_8x1_8_h( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr); | |||||
| } | } | ||||
| if (K > 0) { | if (K > 0) { | ||||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
| inptr7, outptr, 1, K); | |||||
| interleave_8( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr, 1, K); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1136,9 +1135,11 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1153,9 +1154,11 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -22,39 +22,37 @@ using namespace aarch64::matmul; | |||||
| ///////////////////////// gemm_s16_12x8x1 //////////////////////////////////// | ///////////////////////// gemm_s16_12x8x1 //////////////////////////////////// | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s16_12x8x1); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s16_12x8x1); | ||||
| void gemm_s16_12x8x1::pack_A(dt_int16* outptr, const dt_int16* inptr, int ldin, | |||||
| int y0, int ymax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| void gemm_s16_12x8x1::pack_A( | |||||
| dt_int16* outptr, const dt_int16* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n(outptr, inptr, ldin, | |||||
| y0, ymax, k0, kmax); | |||||
| matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n( | |||||
| outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_12x8x1::gemm_s16_12x8x1_pack_A_n(outptr, inptr, ldin, y0, ymax, | |||||
| k0, kmax); | |||||
| matmul_12x8x1::gemm_s16_12x8x1_pack_A_n( | |||||
| outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void gemm_s16_12x8x1::pack_B(dt_int16* out, const dt_int16* in, int ldin, | |||||
| int x0, int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| void gemm_s16_12x8x1::pack_B( | |||||
| dt_int16* out, const dt_int16* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n(out, in, ldin, x0, | |||||
| xmax, k0, kmax); | |||||
| matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n( | |||||
| out, in, ldin, x0, xmax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0, | |||||
| kmax); | |||||
| matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||||
| size_t M, size_t N, size_t K, dt_int32* C, | |||||
| size_t LDC, bool is_first_k, const dt_int32*, | |||||
| dt_int32*) const { | |||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
| (A_dtype.enumv() == DTypeEnum::Int16 && | |||||
| C_dtype.enumv() == DTypeEnum::Int32), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
| C_dtype.name()); | |||||
| void gemm_s16_12x8x1::kern( | |||||
| const dt_int16* packA, const dt_int16* packB, size_t M, size_t N, size_t K, | |||||
| dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
| megdnn_assert( | |||||
| A_dtype.enumv() == B_dtype.enumv() && | |||||
| (A_dtype.enumv() == DTypeEnum::Int16 && | |||||
| C_dtype.enumv() == DTypeEnum::Int32), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
| MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
| MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
| MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
| @@ -72,15 +70,15 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const dt_int16* cur_packB = packB; | const dt_int16* cur_packB = packB; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC, | |||||
| is_first_k); | |||||
| matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC, is_first_k); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K8; | cur_packB += K8; | ||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_12x8x1::kern_12x4(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, std::min<size_t>(N - n, 4)); | |||||
| matmul_12x8x1::kern_12x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| output += 4; | output += 4; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -92,15 +90,15 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||||
| const dt_int16* cur_packB = packB; | const dt_int16* cur_packB = packB; | ||||
| size_t n = 0; | size_t n = 0; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC, | |||||
| is_first_k); | |||||
| matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K8; | cur_packB += K8; | ||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_12x8x1::kern_8x4(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, std::min<size_t>(N - n, 4)); | |||||
| matmul_12x8x1::kern_8x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| output += 4; | output += 4; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -112,16 +110,17 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||||
| const dt_int16* cur_packB = packB; | const dt_int16* cur_packB = packB; | ||||
| size_t n = 0; | size_t n = 0; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_12x8x1::kern_4x8(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, std::min<size_t>(M - m, 4)); | |||||
| matmul_12x8x1::kern_4x8( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4)); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K8; | cur_packB += K8; | ||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_12x8x1::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, std::min<size_t>(M - m, 4), | |||||
| std::min<size_t>(N - n, 4)); | |||||
| matmul_12x8x1::kern_4x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
| output += 4; | output += 4; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -16,11 +16,11 @@ namespace megdnn { | |||||
| namespace aarch64 { | namespace aarch64 { | ||||
| namespace matmul { | namespace matmul { | ||||
| MEGDNN_REG_GEMM_STRATEGY(dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true, | |||||
| gemm_s16_12x8x1); | |||||
| MEGDNN_REG_GEMM_STRATEGY( | |||||
| dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true, gemm_s16_12x8x1); | |||||
| MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_int16, dt_int32, dt_int32, 8, 8, 1, false, | |||||
| true, gemm_nopack_s16_8x8); | |||||
| MEGDNN_REG_GEMM_STRATEGY_NOPACK( | |||||
| dt_int16, dt_int32, dt_int32, 8, 8, 1, false, true, gemm_nopack_s16_8x8); | |||||
| } // namespace matmul | } // namespace matmul | ||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| @@ -9,8 +9,8 @@ | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| */ | */ | ||||
| #include "src/aarch64/matrix_mul/int16/strategy.h" | |||||
| #include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
| #include "src/aarch64/matrix_mul/int16/strategy.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| @@ -20,8 +20,9 @@ using namespace aarch64::matmul; | |||||
| namespace { | namespace { | ||||
| void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
| dt_int32* output) { | |||||
| void kern_8x1( | |||||
| const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
| dt_int32* output) { | |||||
| //! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 | //! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 | ||||
| //! here. | //! here. | ||||
| LDB *= sizeof(dt_int16); | LDB *= sizeof(dt_int16); | ||||
| @@ -91,9 +92,8 @@ void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
| : | : | ||||
| : "v0", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
| "v29", "v30", "v31", "cc", "memory"); | |||||
| : "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", | |||||
| "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); | |||||
| } | } | ||||
| // Overview of register layout: | // Overview of register layout: | ||||
| @@ -120,8 +120,9 @@ void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
| // | v31[0-7]| |v23[0-3]| | // | v31[0-7]| |v23[0-3]| | ||||
| // +---------+ +--------+ | // +---------+ +--------+ | ||||
| // Accumulator | // Accumulator | ||||
| void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
| dt_int32* output) { | |||||
| void kern_8x4( | |||||
| const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
| dt_int32* output) { | |||||
| //! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 | //! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 | ||||
| //! here. | //! here. | ||||
| LDB = (LDB - 24) * sizeof(dt_int16); | LDB = (LDB - 24) * sizeof(dt_int16); | ||||
| @@ -349,9 +350,9 @@ void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", | |||||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
| "v29", "v30", "v31", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||||
| "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", | |||||
| "memory"); | |||||
| } | } | ||||
| // Overview of register layout: | // Overview of register layout: | ||||
| @@ -382,8 +383,9 @@ void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
| // | v7[0-7]| |v30[0-3]|v31[0-3]| | // | v7[0-7]| |v30[0-3]|v31[0-3]| | ||||
| // +--------+ +--------+--------+ | // +--------+ +--------+--------+ | ||||
| // Accumulator | // Accumulator | ||||
| void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
| dt_int32* output) { | |||||
| void kern_8x8( | |||||
| const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
| dt_int32* output) { | |||||
| //! As each load 64 number from B, but the pos add 48 * 2, so we minus 48 | //! As each load 64 number from B, but the pos add 48 * 2, so we minus 48 | ||||
| //! here. | //! here. | ||||
| LDB = (LDB - 48) * sizeof(dt_int16); | LDB = (LDB - 48) * sizeof(dt_int16); | ||||
| @@ -693,20 +695,20 @@ void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | ||||
| [output] "+r"(output), [LDB] "+r"(LDB) | [output] "+r"(output), [LDB] "+r"(LDB) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
| "v29", "v30", "v31", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||||
| "cc", "memory"); | |||||
| } | } | ||||
| } // anonymous namespace | } // anonymous namespace | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_8x8); | MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_8x8); | ||||
| void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, | |||||
| size_t LDB, dt_int32* C, size_t LDC, size_t M, | |||||
| size_t K, size_t N, const dt_int32*, void*, | |||||
| bool trA, bool trB) const { | |||||
| void gemm_nopack_s16_8x8::kern( | |||||
| const dt_int16* A, size_t LDA, const dt_int16* B, size_t LDB, dt_int32* C, | |||||
| size_t LDC, size_t M, size_t K, size_t N, const dt_int32*, void*, bool trA, | |||||
| bool trB) const { | |||||
| constexpr static size_t MB = 8; | constexpr static size_t MB = 8; | ||||
| constexpr static size_t KB = 8; | constexpr static size_t KB = 8; | ||||
| constexpr static size_t NB = 8; | constexpr static size_t NB = 8; | ||||
| @@ -36,9 +36,9 @@ namespace matmul_s4_4x4x16 { | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
| int16_t* output, int LDC, bool is_first_k, int m_remain, | |||||
| int n_remain) { | |||||
| static void s4_kern_8x8_remain( | |||||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
| bool is_first_k, int m_remain, int n_remain) { | |||||
| K /= 8; | K /= 8; | ||||
| LDC = LDC * sizeof(int16_t); | LDC = LDC * sizeof(int16_t); | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| @@ -170,7 +170,7 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
| "dup v5.8b,v20.b[5]\n" | "dup v5.8b,v20.b[5]\n" | ||||
| "dup v6.8b,v20.b[6]\n" | "dup v6.8b,v20.b[6]\n" | ||||
| "dup v7.8b,v20.b[7]\n" | "dup v7.8b,v20.b[7]\n" | ||||
| "ld1 {v17.8b}, [%[b_ptr]], 8\n" | "ld1 {v17.8b}, [%[b_ptr]], 8\n" | ||||
| "dup v8.8b,v20.b[8]\n" | "dup v8.8b,v20.b[8]\n" | ||||
| @@ -318,16 +318,16 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
| STORE_C | STORE_C | ||||
| : | : | ||||
| [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), | |||||
| [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), | |||||
| [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), | |||||
| [ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||||
| [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||||
| [m_remain] "+r"(m_remain), | |||||
| [n_remain] "+r"( | |||||
| n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||||
| : | : | ||||
| : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", | |||||
| "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
| "v29", "v30", "v31"); | |||||
| : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0", | |||||
| "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||||
| "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||||
| "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -335,14 +335,14 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
| #undef STORE_C | #undef STORE_C | ||||
| } | } | ||||
| static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| int16_t* output, int LDC, bool is_first_k, int m_remain, | |||||
| int n_remain) { | |||||
| static void s4_kern_8x8( | |||||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
| bool is_first_k, int m_remain, int n_remain) { | |||||
| K /= 8; | K /= 8; | ||||
| LDC = LDC * sizeof(int16_t); | LDC = LDC * sizeof(int16_t); | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| // clang-format off | |||||
| // clang-format off | |||||
| #define LOAD_C_8 \ | #define LOAD_C_8 \ | ||||
| "ld1 {v24.8h}, [x0], #16\n" \ | "ld1 {v24.8h}, [x0], #16\n" \ | ||||
| @@ -363,9 +363,9 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| "st1 {v28.8h}, [x4], #16\n" \ | "st1 {v28.8h}, [x4], #16\n" \ | ||||
| "st1 {v29.8h}, [x5], #16\n" \ | "st1 {v29.8h}, [x5], #16\n" \ | ||||
| "st1 {v30.8h}, [x6], #16\n" \ | "st1 {v30.8h}, [x6], #16\n" \ | ||||
| "st1 {v31.8h}, [x7], #16\n" \ | |||||
| "st1 {v31.8h}, [x7], #16\n" | |||||
| // clang-format on | |||||
| // clang-format on | |||||
| register int16_t* outptr asm("x0") = output; | register int16_t* outptr asm("x0") = output; | ||||
| asm volatile( | asm volatile( | ||||
| "add x1, x0, %x[LDC]\n" | "add x1, x0, %x[LDC]\n" | ||||
| @@ -395,8 +395,8 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| "PRFM PLDL1KEEP, [%[a_ptr], #512]\n" | "PRFM PLDL1KEEP, [%[a_ptr], #512]\n" | ||||
| "PRFM PLDL1KEEP, [%[b_ptr], #512]\n" | "PRFM PLDL1KEEP, [%[b_ptr], #512]\n" | ||||
| "1:\n" | "1:\n" | ||||
| // "ld1 {v20.16b}, [%[a_ptr]],#16\n" | |||||
| // "ld1 {v21.16b}, [%[a_ptr]],#16\n" | |||||
| // "ld1 {v20.16b}, [%[a_ptr]],#16\n" | |||||
| // "ld1 {v21.16b}, [%[a_ptr]],#16\n" | |||||
| "dup v0.8b,v20.b[0]\n" | "dup v0.8b,v20.b[0]\n" | ||||
| "ld1 {v22.16b}, [%[a_ptr]],#16\n" | "ld1 {v22.16b}, [%[a_ptr]],#16\n" | ||||
| "dup v1.8b,v20.b[1]\n" | "dup v1.8b,v20.b[1]\n" | ||||
| @@ -409,7 +409,6 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| "dup v5.8b,v20.b[5]\n" | "dup v5.8b,v20.b[5]\n" | ||||
| "dup v6.8b,v20.b[6]\n" | "dup v6.8b,v20.b[6]\n" | ||||
| "dup v7.8b,v20.b[7]\n" | "dup v7.8b,v20.b[7]\n" | ||||
| "dup v8.8b,v20.b[8]\n" | "dup v8.8b,v20.b[8]\n" | ||||
| "smlal v24.8h, v0.8b, v16.8b\n" | "smlal v24.8h, v0.8b, v16.8b\n" | ||||
| @@ -560,26 +559,26 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| STORE_C_8 | STORE_C_8 | ||||
| : | : | ||||
| [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), | |||||
| [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), | |||||
| [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), | |||||
| [ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||||
| [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||||
| [m_remain] "+r"(m_remain), | |||||
| [n_remain] "+r"( | |||||
| n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||||
| : | : | ||||
| : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", | |||||
| "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
| "v29", "v30", "v31"); | |||||
| : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0", | |||||
| "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||||
| "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||||
| "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| #undef STORE_LINE | #undef STORE_LINE | ||||
| #undef STORE_C | #undef STORE_C | ||||
| } | } | ||||
| //packa | |||||
| static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* inptr, | |||||
| int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| // packa | |||||
| static void gemm_s4x4x16_8x8x8_transpose_pack( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| int8_t zerobuff[8]; | int8_t zerobuff[8]; | ||||
| int8_t tmpbuff0[8]; | int8_t tmpbuff0[8]; | ||||
| int8_t tmpbuff1[8]; | int8_t tmpbuff1[8]; | ||||
| @@ -617,22 +616,23 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||||
| prefetch_2x(inptr5); | prefetch_2x(inptr5); | ||||
| prefetch_2x(inptr6); | prefetch_2x(inptr6); | ||||
| prefetch_2x(inptr7); | prefetch_2x(inptr7); | ||||
| int K = (kmax - k0)/2; | |||||
| int K = (kmax - k0) / 2; | |||||
| //! read 4 * 16 in each row | //! read 4 * 16 in each row | ||||
| for (; K > 3; K -= 4) { | for (; K > 3; K -= 4) { | ||||
| transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||||
| inptr5, inptr6, inptr7, outptr); | |||||
| transpose_4x8_1_b_with_shift( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr); | |||||
| } | } | ||||
| if (K > 0) { | if (K > 0) { | ||||
| std::memcpy(tmpbuff0,inptr0,K); | |||||
| std::memcpy(tmpbuff1,inptr1,K); | |||||
| std::memcpy(tmpbuff2,inptr2,K); | |||||
| std::memcpy(tmpbuff3,inptr3,K); | |||||
| std::memcpy(tmpbuff4,inptr4,K); | |||||
| std::memcpy(tmpbuff5,inptr5,K); | |||||
| std::memcpy(tmpbuff6,inptr6,K); | |||||
| std::memcpy(tmpbuff7,inptr7,K); | |||||
| std::memcpy(tmpbuff0, inptr0, K); | |||||
| std::memcpy(tmpbuff1, inptr1, K); | |||||
| std::memcpy(tmpbuff2, inptr2, K); | |||||
| std::memcpy(tmpbuff3, inptr3, K); | |||||
| std::memcpy(tmpbuff4, inptr4, K); | |||||
| std::memcpy(tmpbuff5, inptr5, K); | |||||
| std::memcpy(tmpbuff6, inptr6, K); | |||||
| std::memcpy(tmpbuff7, inptr7, K); | |||||
| inptr0 = tmpbuff0; | inptr0 = tmpbuff0; | ||||
| inptr1 = tmpbuff1; | inptr1 = tmpbuff1; | ||||
| inptr2 = tmpbuff2; | inptr2 = tmpbuff2; | ||||
| @@ -641,8 +641,9 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||||
| inptr5 = tmpbuff5; | inptr5 = tmpbuff5; | ||||
| inptr6 = tmpbuff6; | inptr6 = tmpbuff6; | ||||
| inptr7 = tmpbuff7; | inptr7 = tmpbuff7; | ||||
| transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||||
| inptr5, inptr6, inptr7, outptr); | |||||
| transpose_4x8_1_b_with_shift( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr); | |||||
| } | } | ||||
| } | } | ||||
| for (; y < ymax; y += 8) { | for (; y < ymax; y += 8) { | ||||
| @@ -655,23 +656,29 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||||
| const int8_t* inptr6 = inptr5 + ldin; | const int8_t* inptr6 = inptr5 + ldin; | ||||
| const int8_t* inptr7 = inptr6 + ldin; | const int8_t* inptr7 = inptr6 + ldin; | ||||
| int K = (kmax - k0)/2; | |||||
| int K = (kmax - k0) / 2; | |||||
| //! read 4 * 16 in each row | //! read 4 * 16 in each row | ||||
| for (; K > 3; K -= 4) { | for (; K > 3; K -= 4) { | ||||
| if (y + 7 >= ymax) { | if (y + 7 >= ymax) { | ||||
| switch (y + 7 - ymax) { | switch (y + 7 - ymax) { | ||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -679,24 +686,31 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||||
| megdnn_assert(0); | megdnn_assert(0); | ||||
| } | } | ||||
| } | } | ||||
| transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||||
| inptr5, inptr6, inptr7, outptr); | |||||
| transpose_4x8_1_b_with_shift( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr); | |||||
| } | } | ||||
| if (K > 0) { | if (K > 0) { | ||||
| if (y + 7 >= ymax) { | if (y + 7 >= ymax) { | ||||
| switch (y + 7 - ymax) { | switch (y + 7 - ymax) { | ||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -705,14 +719,14 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||||
| } | } | ||||
| } | } | ||||
| std::memcpy(tmpbuff0,inptr0,K); | |||||
| std::memcpy(tmpbuff1,inptr1,K); | |||||
| std::memcpy(tmpbuff2,inptr2,K); | |||||
| std::memcpy(tmpbuff3,inptr3,K); | |||||
| std::memcpy(tmpbuff4,inptr4,K); | |||||
| std::memcpy(tmpbuff5,inptr5,K); | |||||
| std::memcpy(tmpbuff6,inptr6,K); | |||||
| std::memcpy(tmpbuff7,inptr7,K); | |||||
| std::memcpy(tmpbuff0, inptr0, K); | |||||
| std::memcpy(tmpbuff1, inptr1, K); | |||||
| std::memcpy(tmpbuff2, inptr2, K); | |||||
| std::memcpy(tmpbuff3, inptr3, K); | |||||
| std::memcpy(tmpbuff4, inptr4, K); | |||||
| std::memcpy(tmpbuff5, inptr5, K); | |||||
| std::memcpy(tmpbuff6, inptr6, K); | |||||
| std::memcpy(tmpbuff7, inptr7, K); | |||||
| inptr0 = tmpbuff0; | inptr0 = tmpbuff0; | ||||
| inptr1 = tmpbuff1; | inptr1 = tmpbuff1; | ||||
| inptr2 = tmpbuff2; | inptr2 = tmpbuff2; | ||||
| @@ -721,14 +735,15 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||||
| inptr5 = tmpbuff5; | inptr5 = tmpbuff5; | ||||
| inptr6 = tmpbuff6; | inptr6 = tmpbuff6; | ||||
| inptr7 = tmpbuff7; | inptr7 = tmpbuff7; | ||||
| transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||||
| inptr5, inptr6, inptr7, outptr); | |||||
| transpose_4x8_1_b_with_shift( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| //packb | |||||
| static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, int ldin, | |||||
| int x0, int xmax, int k0, int kmax) { | |||||
| // packb | |||||
| static void gemm_s4x4x16_8x8x8_interleave_pack( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| int8_t zerobuff[8]; | int8_t zerobuff[8]; | ||||
| int8_t tmpbuff0[8]; | int8_t tmpbuff0[8]; | ||||
| int8_t tmpbuff1[8]; | int8_t tmpbuff1[8]; | ||||
| @@ -748,7 +763,7 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||||
| std::memset(tmpbuff6, 0, sizeof(int8_t) * 8); | std::memset(tmpbuff6, 0, sizeof(int8_t) * 8); | ||||
| std::memset(tmpbuff7, 0, sizeof(int8_t) * 8); | std::memset(tmpbuff7, 0, sizeof(int8_t) * 8); | ||||
| const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
| const int ksize8 = round_up(ksize, 8) * 8; //pack to int8 *8 packto s4 *4 | |||||
| const int ksize8 = round_up(ksize, 8) * 8; // pack to int8 *8 packto s4 *4 | |||||
| int8_t* outptr = out; | int8_t* outptr = out; | ||||
| int8_t* outptr_interleave = nullptr; | int8_t* outptr_interleave = nullptr; | ||||
| @@ -776,21 +791,22 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||||
| int8_t* outptr_inner = outptr; | int8_t* outptr_inner = outptr; | ||||
| for (; x + 3 < xmax; x += 4) { | for (; x + 3 < xmax; x += 4) { | ||||
| outptr_interleave = outptr_inner; | outptr_interleave = outptr_inner; | ||||
| interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, outptr_interleave); | |||||
| interleave_8x4_1_b_with_shift( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr_interleave); | |||||
| outptr_inner += ksize8; | outptr_inner += ksize8; | ||||
| } | } | ||||
| if (x < xmax) { | if (x < xmax) { | ||||
| int remainx = xmax - x; | int remainx = xmax - x; | ||||
| std::memcpy(tmpbuff0,inptr0,remainx); | |||||
| std::memcpy(tmpbuff1,inptr1,remainx); | |||||
| std::memcpy(tmpbuff2,inptr2,remainx); | |||||
| std::memcpy(tmpbuff3,inptr3,remainx); | |||||
| std::memcpy(tmpbuff4,inptr4,remainx); | |||||
| std::memcpy(tmpbuff5,inptr5,remainx); | |||||
| std::memcpy(tmpbuff6,inptr6,remainx); | |||||
| std::memcpy(tmpbuff7,inptr7,remainx); | |||||
| std::memcpy(tmpbuff0, inptr0, remainx); | |||||
| std::memcpy(tmpbuff1, inptr1, remainx); | |||||
| std::memcpy(tmpbuff2, inptr2, remainx); | |||||
| std::memcpy(tmpbuff3, inptr3, remainx); | |||||
| std::memcpy(tmpbuff4, inptr4, remainx); | |||||
| std::memcpy(tmpbuff5, inptr5, remainx); | |||||
| std::memcpy(tmpbuff6, inptr6, remainx); | |||||
| std::memcpy(tmpbuff7, inptr7, remainx); | |||||
| inptr0 = tmpbuff0; | inptr0 = tmpbuff0; | ||||
| inptr1 = tmpbuff1; | inptr1 = tmpbuff1; | ||||
| inptr2 = tmpbuff2; | inptr2 = tmpbuff2; | ||||
| @@ -801,8 +817,9 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||||
| inptr7 = tmpbuff7; | inptr7 = tmpbuff7; | ||||
| outptr_interleave = outptr_inner; | outptr_interleave = outptr_inner; | ||||
| interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, outptr_interleave); | |||||
| interleave_8x4_1_b_with_shift( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr_interleave); | |||||
| outptr_inner += ksize8; | outptr_inner += ksize8; | ||||
| } | } | ||||
| outptr += 64; | outptr += 64; | ||||
| @@ -847,8 +864,9 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||||
| break; | break; | ||||
| } | } | ||||
| outptr_interleave = outptr_inner; | outptr_interleave = outptr_inner; | ||||
| interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, outptr_interleave); | |||||
| interleave_8x4_1_b_with_shift( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr_interleave); | |||||
| outptr_inner += ksize8; | outptr_inner += ksize8; | ||||
| } | } | ||||
| if (x < xmax) { | if (x < xmax) { | ||||
| @@ -880,14 +898,14 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||||
| } | } | ||||
| int remainx = xmax - x; | int remainx = xmax - x; | ||||
| outptr_interleave = outptr_inner; | outptr_interleave = outptr_inner; | ||||
| std::memcpy(tmpbuff0,inptr0,remainx); | |||||
| std::memcpy(tmpbuff1,inptr1,remainx); | |||||
| std::memcpy(tmpbuff2,inptr2,remainx); | |||||
| std::memcpy(tmpbuff3,inptr3,remainx); | |||||
| std::memcpy(tmpbuff4,inptr4,remainx); | |||||
| std::memcpy(tmpbuff5,inptr5,remainx); | |||||
| std::memcpy(tmpbuff6,inptr6,remainx); | |||||
| std::memcpy(tmpbuff7,inptr7,remainx); | |||||
| std::memcpy(tmpbuff0, inptr0, remainx); | |||||
| std::memcpy(tmpbuff1, inptr1, remainx); | |||||
| std::memcpy(tmpbuff2, inptr2, remainx); | |||||
| std::memcpy(tmpbuff3, inptr3, remainx); | |||||
| std::memcpy(tmpbuff4, inptr4, remainx); | |||||
| std::memcpy(tmpbuff5, inptr5, remainx); | |||||
| std::memcpy(tmpbuff6, inptr6, remainx); | |||||
| std::memcpy(tmpbuff7, inptr7, remainx); | |||||
| inptr0 = tmpbuff0; | inptr0 = tmpbuff0; | ||||
| inptr1 = tmpbuff1; | inptr1 = tmpbuff1; | ||||
| inptr2 = tmpbuff2; | inptr2 = tmpbuff2; | ||||
| @@ -898,16 +916,16 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||||
| inptr7 = tmpbuff7; | inptr7 = tmpbuff7; | ||||
| outptr_interleave = outptr_inner; | outptr_interleave = outptr_inner; | ||||
| interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, outptr_interleave); | |||||
| interleave_8x4_1_b_with_shift( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr_interleave); | |||||
| outptr_inner += ksize8; | outptr_inner += ksize8; | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } // namespace matmul_4x4x16 | |||||
| } // namespace matmul_s4_4x4x16 | |||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -10,9 +10,9 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "src/aarch64/matrix_mul/int4x4x16/strategy.h" | |||||
| #include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
| #include "src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h" | #include "src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h" | ||||
| #include "src/aarch64/matrix_mul/int4x4x16/strategy.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/matrix_mul/gemm_common.h" | #include "src/fallback/matrix_mul/gemm_common.h" | ||||
| @@ -23,39 +23,38 @@ using namespace aarch64::matmul; | |||||
| // ===========================gemm_s4x4x16_s4_8x8x8================================== | // ===========================gemm_s4x4x16_s4_8x8x8================================== | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s4x4x16_s4_8x8x8); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s4x4x16_s4_8x8x8); | ||||
| void gemm_s4x4x16_s4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | |||||
| int ymax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| void gemm_s4x4x16_s4_8x8x8::pack_A( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack( | |||||
| out, in, ldin, y0, ymax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack( | |||||
| out, in, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void gemm_s4x4x16_s4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
| int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| void gemm_s4x4x16_s4_8x8x8::pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, x0, xmax, k0, | |||||
| kmax); | |||||
| matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack( | |||||
| out, in, ldin, x0, xmax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, x0, xmax, k0, | |||||
| kmax); | |||||
| matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack( | |||||
| out, in, ldin, x0, xmax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| size_t M, size_t N, size_t K, dt_int16* C, | |||||
| size_t LDC, bool is_first_k, const dt_int16*, | |||||
| dt_int16*) const { | |||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
| (A_dtype.enumv() == DTypeEnum::QuantizedS4 && | |||||
| C_dtype.enumv() == DTypeEnum::QuantizedS16), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
| C_dtype.name()); | |||||
| void gemm_s4x4x16_s4_8x8x8::kern( | |||||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
| dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||||
| megdnn_assert( | |||||
| A_dtype.enumv() == B_dtype.enumv() && | |||||
| (A_dtype.enumv() == DTypeEnum::QuantizedS4 && | |||||
| C_dtype.enumv() == DTypeEnum::QuantizedS16), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
| MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
| MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
| MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
| @@ -72,16 +71,17 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_s4_4x4x16::s4_kern_8x8(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, A_INTERLEAVE, B_INTERLEAVE); | |||||
| matmul_s4_4x4x16::s4_kern_8x8( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, | |||||
| B_INTERLEAVE); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K8; | cur_packB += K8; | ||||
| } | } | ||||
| for (; n < N; n += B_INTERLEAVE) { | for (; n < N; n += B_INTERLEAVE) { | ||||
| matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, A_INTERLEAVE, | |||||
| std::min<size_t>(N - n, B_INTERLEAVE)); | |||||
| matmul_s4_4x4x16::s4_kern_8x8_remain( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, | |||||
| std::min<size_t>(N - n, B_INTERLEAVE)); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K8; | cur_packB += K8; | ||||
| } | } | ||||
| @@ -94,10 +94,10 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| for (; n < N; n += B_INTERLEAVE) { | for (; n < N; n += B_INTERLEAVE) { | ||||
| matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, | |||||
| std::min<size_t>(M - m, A_INTERLEAVE), | |||||
| std::min<size_t>(N - n, B_INTERLEAVE)); | |||||
| matmul_s4_4x4x16::s4_kern_8x8_remain( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, A_INTERLEAVE), | |||||
| std::min<size_t>(N - n, B_INTERLEAVE)); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K8; | cur_packB += K8; | ||||
| } | } | ||||
| @@ -105,5 +105,4 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| } | } | ||||
| } | } | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -17,8 +17,8 @@ namespace megdnn { | |||||
| namespace aarch64 { | namespace aarch64 { | ||||
| namespace matmul { | namespace matmul { | ||||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, | |||||
| gemm_s4x4x16_s4_8x8x8); | |||||
| MEGDNN_REG_GEMM_STRATEGY( | |||||
| dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, gemm_s4x4x16_s4_8x8x8); | |||||
| } // namespace matmul | } // namespace matmul | ||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| @@ -51,8 +51,9 @@ namespace matmul_4x4x16 { | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k) { | |||||
| static void kern_4x4( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k) { | |||||
| K /= 16; | K /= 16; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -472,9 +473,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| ); | ); | ||||
| } | } | ||||
| static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k, | |||||
| int m_remain, int n_remain) { | |||||
| static void kern_4x4_remain( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k, int m_remain, int n_remain) { | |||||
| megdnn_assert(K > 0); | megdnn_assert(K > 0); | ||||
| K /= 16; | K /= 16; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| @@ -655,16 +656,14 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
| STORE_C | STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [output] "+r"(output), [m_remain] "+r"(m_remain), | |||||
| [n_remain] "+r"(n_remain) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output), | |||||
| [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
| "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "cc", | |||||
| "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||||
| "x0", "x1", "x2", "x3", "x4", "x5", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -672,8 +671,9 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
| #undef STORE_C | #undef STORE_C | ||||
| } | } | ||||
| static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
| int ldin, int y0, int ymax, int k0, int kmax) { | |||||
| static void gemm_s8_4x4_pack_A_n( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| @@ -716,9 +716,11 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -734,9 +736,11 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -749,8 +753,8 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| int x0, int xmax, int k0, int kmax) { | |||||
| static void gemm_s8_4x4_pack_B_n( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
| @@ -777,19 +781,26 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| if (remain >= 0) { | if (remain >= 0) { | ||||
| switch (remain) { | switch (remain) { | ||||
| case 7: | case 7: | ||||
| inptr0 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr0 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -798,9 +809,9 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| } | } | ||||
| } | } | ||||
| transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, | |||||
| inptr4, inptr5, inptr6, inptr7, | |||||
| outptr_inner); | |||||
| transpose_4x16_1_b_helper( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr_inner); | |||||
| outptr_inner += ksize4; | outptr_inner += ksize4; | ||||
| } | } | ||||
| @@ -808,19 +819,26 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| if (remain >= 0) { | if (remain >= 0) { | ||||
| switch (remain) { | switch (remain) { | ||||
| case 7: | case 7: | ||||
| inptr0 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr0 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -42,8 +42,9 @@ namespace matmul_8x8x8 { | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k) { | |||||
| static void kern_8x8( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k) { | |||||
| K /= 8; | K /= 8; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -272,14 +273,13 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| "stp q18, q19, [x5]\n" | "stp q18, q19, [x5]\n" | ||||
| "stp q20, q21, [x6]\n" | "stp q20, q21, [x6]\n" | ||||
| "stp q22, q23, [x7]\n" | "stp q22, q23, [x7]\n" | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [output] "+r"(output) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
| "v20", "v21", "v22", "v23", "v26", "v27", "x1", | |||||
| "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
| "v22", "v23", "v26", "v27", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||||
| "cc", "memory"); | |||||
| } | } | ||||
| /** | /** | ||||
| @@ -309,9 +309,9 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k, | |||||
| size_t n_remain) { | |||||
| static void kern_8x4( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k, size_t n_remain) { | |||||
| K /= 8; | K /= 8; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -520,16 +520,14 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| "cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
| "3:\n" STORE_C | "3:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||||
| [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), | |||||
| [n_remain] "+r"(n_remain) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||||
| [outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -559,9 +557,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k, | |||||
| size_t m_remain) { | |||||
| static void kern_4x8( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k, size_t m_remain) { | |||||
| K /= 8; | K /= 8; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -724,14 +722,13 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| "cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
| "3:\n" STORE_C | "3:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [outptr0] "+r"(outptr0), | |||||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), | |||||
| [outptr3] "=r"(outptr3), [x0] "+r"(x0), [m_remain] "+r"(m_remain) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [x0] "+r"(x0), [m_remain] "+r"(m_remain) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -762,9 +759,9 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k, size_t m_remain, | |||||
| size_t n_remain) { | |||||
| static void kern_4x4( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k, size_t m_remain, size_t n_remain) { | |||||
| K /= 8; | K /= 8; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -922,11 +919,10 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| "cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
| "3:\n" STORE_C | "3:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), | |||||
| [x1] "+r"(x1), [m_remain] "+r"(m_remain), | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [x0] "+r"(x0), [x1] "+r"(x1), [m_remain] "+r"(m_remain), | |||||
| [n_remain] "+r"(n_remain) | [n_remain] "+r"(n_remain) | ||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v11", "cc", | : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v11", "cc", | ||||
| @@ -938,8 +934,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| #undef STORE_C | #undef STORE_C | ||||
| } | } | ||||
| static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||||
| int y0, int ymax, int k0, int kmax) { | |||||
| static void gemm_s8_8x8_pack_A_n( | |||||
| int8_t* outptr, const int8_t* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| @@ -965,13 +962,15 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||||
| int K = kmax - k0; | int K = kmax - k0; | ||||
| for (; K > 15; K -= 16) { | for (; K > 15; K -= 16) { | ||||
| interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, outptr); | |||||
| interleave_8x8_2_b( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr); | |||||
| } | } | ||||
| if (K > 0) { | if (K > 0) { | ||||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
| inptr7, outptr, 8, K); | |||||
| interleave_8( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr, 8, K); | |||||
| } | } | ||||
| } | } | ||||
| @@ -991,9 +990,11 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1009,9 +1010,11 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1024,9 +1027,8 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||||
| int ldin, int x0, int xmax, int k0, | |||||
| int kmax) { | |||||
| static void gemm_s8_8x8_transpose_pack_A_n( | |||||
| int8_t* out, const int8_t* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
| @@ -1063,17 +1065,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||||
| if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
| switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1081,8 +1089,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||||
| megdnn_assert(0); | megdnn_assert(0); | ||||
| } | } | ||||
| } | } | ||||
| transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, outptr); | |||||
| transpose_8x8_1_b( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr); | |||||
| outptr += ksize8; | outptr += ksize8; | ||||
| } | } | ||||
| @@ -1091,17 +1100,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||||
| if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
| switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1110,8 +1125,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||||
| } | } | ||||
| } | } | ||||
| transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
| inptr7, outptr, 4, 4); | |||||
| transpose_8( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr, 4, 4); | |||||
| outptr += ksize4; | outptr += ksize4; | ||||
| } | } | ||||
| @@ -1119,17 +1135,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||||
| if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
| switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1138,8 +1160,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||||
| } | } | ||||
| } | } | ||||
| transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
| inptr7, outptr, 4, xmax - x); | |||||
| transpose_8( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr, 4, xmax - x); | |||||
| } | } | ||||
| outptr_base += 8 * 8; | outptr_base += 8 * 8; | ||||
| @@ -1147,8 +1170,8 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||||
| int x0, int xmax, int k0, int kmax) { | |||||
| static void gemm_s8_8x8_pack_B_n( | |||||
| int8_t* out, const int8_t* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
| @@ -1186,17 +1209,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||||
| if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
| switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1205,8 +1234,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||||
| } | } | ||||
| } | } | ||||
| outptr_interleave = outptr; | outptr_interleave = outptr; | ||||
| interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, outptr_interleave); | |||||
| interleave_8x8_1_b( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr_interleave); | |||||
| outptr += ksize8; | outptr += ksize8; | ||||
| } | } | ||||
| @@ -1215,17 +1245,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||||
| if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
| switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1235,8 +1271,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||||
| } | } | ||||
| outptr_interleave = outptr; | outptr_interleave = outptr; | ||||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
| inptr7, outptr_interleave, 4, 4); | |||||
| interleave_8( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr_interleave, 4, 4); | |||||
| outptr += ksize4; | outptr += ksize4; | ||||
| } | } | ||||
| @@ -1244,17 +1281,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||||
| if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
| switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1264,8 +1307,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||||
| } | } | ||||
| outptr_interleave = outptr; | outptr_interleave = outptr; | ||||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
| inptr7, outptr_interleave, 4, xmax - x); | |||||
| interleave_8( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr_interleave, 4, xmax - x); | |||||
| } | } | ||||
| outptr_base += 8 * 8; | outptr_base += 8 * 8; | ||||
| @@ -1273,9 +1317,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||||
| int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| static void gemm_s8_8x8_transpose_pack_B_n( | |||||
| int8_t* outptr, const int8_t* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| constexpr int interleave4 = 32; | constexpr int interleave4 = 32; | ||||
| @@ -1303,14 +1347,16 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||||
| int K = kmax - k0; | int K = kmax - k0; | ||||
| for (; K > 7; K -= 8) { | for (; K > 7; K -= 8) { | ||||
| transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, outptr); | |||||
| transpose_8x8_1_b( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr); | |||||
| outptr += interleave8; | outptr += interleave8; | ||||
| } | } | ||||
| if (K > 0) { | if (K > 0) { | ||||
| transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
| inptr7, outptr, 8, K); | |||||
| transpose_8( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr, 8, K); | |||||
| outptr += interleave8; | outptr += interleave8; | ||||
| } | } | ||||
| } | } | ||||
| @@ -1331,9 +1377,11 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1350,9 +1398,11 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -50,8 +50,9 @@ namespace matmul_mk4_4x4x16 { | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, bool is_first_k) { | |||||
| static void kern_4x4( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, | |||||
| bool is_first_k) { | |||||
| K = div_ceil(K, 16); | K = div_ceil(K, 16); | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -366,17 +367,18 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| "6:\n" | "6:\n" | ||||
| "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[output]], #64\n" | "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[output]], #64\n" | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
| [is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [k] "+r"(K), [output] "+r"(output) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
| "v29", "v30", "v31", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||||
| "cc", "memory"); | |||||
| } | } | ||||
| static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, bool is_first_k, size_t remain_n) { | |||||
| static void kern_4x4_remain( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, | |||||
| bool is_first_k, size_t remain_n) { | |||||
| K = div_ceil(K, 16); | K = div_ceil(K, 16); | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -718,26 +720,27 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
| "7:\n" | "7:\n" | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
| [remain_n] "+r"(remain_n), [is_first_k] "+r"(is_first_k), | |||||
| [k] "+r"(K), [output] "+r"(output) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [remain_n] "+r"(remain_n), | |||||
| [is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
| "v29", "v30", "v31", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||||
| "cc", "memory"); | |||||
| } | } | ||||
| static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
| int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| static void gemm_mk4_s8_4x4_pack_A( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| //! pack form {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)} | //! pack form {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)} | ||||
| int8_t zerobuff[4][64]; | int8_t zerobuff[4][64]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4); | std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4); | ||||
| megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0, | |||||
| "mk4 matmul with m is not times of 4"); | |||||
| megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0, | |||||
| "mk4 matmul with k is not times of 4"); | |||||
| megdnn_assert( | |||||
| ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0, | |||||
| "mk4 matmul with m is not times of 4"); | |||||
| megdnn_assert( | |||||
| kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0, | |||||
| "mk4 matmul with k is not times of 4"); | |||||
| size_t roundk = round_up(kmax - k0, 16); | size_t roundk = round_up(kmax - k0, 16); | ||||
| size_t out_offset = roundk * 4; | size_t out_offset = roundk * 4; | ||||
| int y = y0; | int y = y0; | ||||
| @@ -754,8 +757,8 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
| prefetch_2x(inptr3); | prefetch_2x(inptr3); | ||||
| int K = kmax - k0; | int K = kmax - k0; | ||||
| for (; K > 15; K -= 16) { | for (; K > 15; K -= 16) { | ||||
| transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, | |||||
| out_offset); | |||||
| transpose_interleave_4x4_4_b( | |||||
| inptr0, inptr1, inptr2, inptr3, output, out_offset); | |||||
| output += 64; | output += 64; | ||||
| } | } | ||||
| if (K > 0) { | if (K > 0) { | ||||
| @@ -767,8 +770,8 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
| inptr1 = zerobuff[1]; | inptr1 = zerobuff[1]; | ||||
| inptr2 = zerobuff[2]; | inptr2 = zerobuff[2]; | ||||
| inptr3 = zerobuff[3]; | inptr3 = zerobuff[3]; | ||||
| transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, | |||||
| out_offset); | |||||
| transpose_interleave_4x4_4_b( | |||||
| inptr0, inptr1, inptr2, inptr3, output, out_offset); | |||||
| output += 64; | output += 64; | ||||
| } | } | ||||
| } | } | ||||
| @@ -790,21 +793,21 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
| int x0, int xmax, int k0, int kmax) { | |||||
| static void gemm_mk4_s8_4x4_pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| int32_t zerobuff[4]; | int32_t zerobuff[4]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
| const int ICB = (ksize) / 4; | const int ICB = (ksize) / 4; | ||||
| const int ksize4 = round_up<int>(ICB, 4) * 4; | const int ksize4 = round_up<int>(ICB, 4) * 4; | ||||
| int32_t* outptr = reinterpret_cast<int32_t*>(out); | int32_t* outptr = reinterpret_cast<int32_t*>(out); | ||||
| megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0, | |||||
| "mk4 matmul with k is not times of 4"); | |||||
| megdnn_assert( | |||||
| kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0, | |||||
| "mk4 matmul with k is not times of 4"); | |||||
| int k = k0 / 4; | int k = k0 / 4; | ||||
| for (; k + 3 < ICB; k += 4) { | for (; k + 3 < ICB; k += 4) { | ||||
| const int32_t* inptr0 = | |||||
| reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||||
| const int32_t* inptr0 = reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||||
| const int32_t* inptr1 = | const int32_t* inptr1 = | ||||
| reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0); | reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0); | ||||
| const int32_t* inptr2 = | const int32_t* inptr2 = | ||||
| @@ -829,8 +832,7 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
| outptr += 4 * 4; | outptr += 4 * 4; | ||||
| } | } | ||||
| if (k < ICB) { | if (k < ICB) { | ||||
| const int32_t* inptr0 = | |||||
| reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||||
| const int32_t* inptr0 = reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||||
| const int32_t* inptr1 = | const int32_t* inptr1 = | ||||
| reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0); | reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0); | ||||
| const int32_t* inptr2 = | const int32_t* inptr2 = | ||||
| @@ -844,9 +846,11 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
| if (k + 3 >= ICB) { | if (k + 3 >= ICB) { | ||||
| switch (k + 3 - ICB) { | switch (k + 3 - ICB) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -861,9 +865,11 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
| if (k + 3 >= ICB) { | if (k + 3 >= ICB) { | ||||
| switch (k + 3 - ICB) { | switch (k + 3 - ICB) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -882,7 +888,7 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
| } | } | ||||
| } | } | ||||
| } // namespace matmul_4x4x16 | |||||
| } // namespace matmul_mk4_4x4x16 | |||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -24,20 +24,19 @@ using namespace aarch64::matmul; | |||||
| ///////////////////////// gemm_s8_4x4 //////////////////////////////////// | ///////////////////////// gemm_s8_4x4 //////////////////////////////////// | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4); | ||||
| void gemm_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||||
| int y0, int ymax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| void gemm_s8_4x4::pack_A( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
| int xmax, int k0, int kmax, bool transpose) const { | |||||
| void gemm_s8_4x4::pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | ||||
| } else { | } else { | ||||
| @@ -45,16 +44,16 @@ void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
| } | } | ||||
| } | } | ||||
| void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| size_t N, size_t K, dt_int32* C, size_t LDC, | |||||
| bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
| C_dtype.name()); | |||||
| void gemm_s8_4x4::kern( | |||||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
| dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
| megdnn_assert( | |||||
| A_dtype.enumv() == B_dtype.enumv() && | |||||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
| MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
| MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
| MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
| @@ -72,16 +71,15 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
| is_first_k); | |||||
| matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| for (; n < N; n += B_INTERLEAVE) { | for (; n < N; n += B_INTERLEAVE) { | ||||
| matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, 4, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| matmul_4x4x16::kern_4x4_remain( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, 4, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -107,33 +105,32 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| ///////////////////////// gemm_mk4_s8_4x4 //////////////////////////////////// | ///////////////////////// gemm_mk4_s8_4x4 //////////////////////////////////// | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4); | ||||
| void gemm_mk4_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||||
| int y0, int ymax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| megdnn_assert(!transpose, | |||||
| "the gemm_mk4_s8_4x4 strategy is not support transpose A"); | |||||
| matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| void gemm_mk4_s8_4x4::pack_A( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| megdnn_assert( | |||||
| !transpose, "the gemm_mk4_s8_4x4 strategy is not support transpose A"); | |||||
| matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| void gemm_mk4_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
| int xmax, int k0, int kmax, bool transpose) const { | |||||
| megdnn_assert(!transpose, | |||||
| "the gemm_mk4_s8_4x4 strategy is not support transpose B"); | |||||
| matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0, | |||||
| kmax); | |||||
| void gemm_mk4_s8_4x4::pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| megdnn_assert( | |||||
| !transpose, "the gemm_mk4_s8_4x4 strategy is not support transpose B"); | |||||
| matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0, kmax); | |||||
| } | } | ||||
| void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| size_t N, size_t K, dt_int32* C, size_t LDC, | |||||
| bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
| C_dtype.name()); | |||||
| void gemm_mk4_s8_4x4::kern( | |||||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
| dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
| megdnn_assert( | |||||
| A_dtype.enumv() == B_dtype.enumv() && | |||||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
| MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
| MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
| MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
| @@ -151,57 +148,54 @@ void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output, | |||||
| is_first_k); | |||||
| matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output, is_first_k); | |||||
| output += B_INTERLEAVE * 4; | output += B_INTERLEAVE * 4; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| if (n < N) { | if (n < N) { | ||||
| matmul_mk4_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, | |||||
| is_first_k, N - n); | |||||
| matmul_mk4_4x4x16::kern_4x4_remain( | |||||
| packA, cur_packB, K, output, is_first_k, N - n); | |||||
| } | } | ||||
| packA += K4; | packA += K4; | ||||
| } | } | ||||
| } | } | ||||
| ///////////////////////// gemm_s8_8x8 //////////////////////////////////// | ///////////////////////// gemm_s8_8x8 //////////////////////////////////// | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x8); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x8); | ||||
| void gemm_s8_8x8::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||||
| int y0, int ymax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| void gemm_s8_8x8::pack_A( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, | |||||
| ymax, k0, kmax); | |||||
| matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n( | |||||
| outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void gemm_s8_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
| int xmax, int k0, int kmax, bool transpose) const { | |||||
| void gemm_s8_8x8::pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, | |||||
| k0, kmax); | |||||
| matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_8x8x8::gemm_s8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | matmul_8x8x8::gemm_s8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | ||||
| } | } | ||||
| } | } | ||||
| void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| size_t N, size_t K, dt_int32* C, size_t LDC, | |||||
| bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
| C_dtype.name()); | |||||
| void gemm_s8_8x8::kern( | |||||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
| dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
| megdnn_assert( | |||||
| A_dtype.enumv() == B_dtype.enumv() && | |||||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
| MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
| MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
| MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
| @@ -220,15 +214,15 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, | |||||
| is_first_k); | |||||
| matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K8; | cur_packB += K8; | ||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| matmul_8x8x8::kern_8x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| output += 4; | output += 4; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -240,16 +234,17 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| size_t n = 0; | size_t n = 0; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4)); | |||||
| matmul_8x8x8::kern_4x8( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4)); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K8; | cur_packB += K8; | ||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4), | |||||
| std::min<size_t>(N - n, 4)); | |||||
| matmul_8x8x8::kern_4x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
| output += 4; | output += 4; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -16,14 +16,14 @@ namespace megdnn { | |||||
| namespace aarch64 { | namespace aarch64 { | ||||
| namespace matmul { | namespace matmul { | ||||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true, | |||||
| gemm_s8_4x4); | |||||
| MEGDNN_REG_GEMM_STRATEGY( | |||||
| dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true, gemm_s8_4x4); | |||||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false, | |||||
| gemm_mk4_s8_4x4); | |||||
| MEGDNN_REG_GEMM_STRATEGY( | |||||
| dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false, gemm_mk4_s8_4x4); | |||||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, | |||||
| gemm_s8_8x8); | |||||
| MEGDNN_REG_GEMM_STRATEGY( | |||||
| dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, gemm_s8_8x8); | |||||
| } // namespace matmul | } // namespace matmul | ||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| @@ -52,8 +52,9 @@ namespace matmul_8x12x4 { | |||||
| #if 1 | #if 1 | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
| static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k) { | |||||
| static void kern_8x12( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k) { | |||||
| K /= 4; | K /= 4; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -410,8 +411,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| } | } | ||||
| #else | #else | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
| static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k) { | |||||
| static void kern_8x12( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k) { | |||||
| K /= 4; | K /= 4; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -612,18 +614,17 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| "stp q15, q23, [%[outptr7]]\n" | "stp q15, q23, [%[outptr7]]\n" | ||||
| "str q31, [%[outptr7], #32]\n" | "str q31, [%[outptr7], #32]\n" | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), | |||||
| [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), | |||||
| [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k), [LDC] "+r"(LDC), | |||||
| [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | |||||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||||
| [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), [a1] "+w"(a1), | |||||
| [a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), [b1] "+w"(b1), | |||||
| [b2] "+w"(b2), [k] "+r"(k), [LDC] "+r"(LDC), [oddk] "+r"(oddk), | |||||
| [is_first_k] "+r"(is_first_k), [outptr0] "+r"(outptr0), | |||||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||||
| [outptr7] "=r"(outptr7) | |||||
| : | : | ||||
| : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
| "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |||||
| "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); | |||||
| : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
| "v29", "v30", "v31", "cc", "memory"); | |||||
| } | } | ||||
| #endif | #endif | ||||
| @@ -653,8 +654,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
| static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k, int m_remain) { | |||||
| static void kern_4x12( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k, int m_remain) { | |||||
| K /= 4; | K /= 4; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -796,15 +798,15 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| "4:\n" STORE_C | "4:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), | ||||
| [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), | |||||
| [is_first_k] "+r"(is_first_k), [m_remain] "+r"(m_remain), | |||||
| [LDC] "+r"(LDC), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), | |||||
| [b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a), [b1a] "=w"(b1a), | |||||
| [b2a] "=w"(b2a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), | |||||
| [outptr3] "=r"(outptr3), [x0] "=r"(x0) | |||||
| [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | |||||
| [m_remain] "+r"(m_remain), [LDC] "+r"(LDC), [a0] "=w"(a0), | |||||
| [a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2), | |||||
| [b0a] "=w"(b0a), [b1a] "=w"(b1a), [b2a] "=w"(b2a), | |||||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [x0] "=r"(x0) | |||||
| : | : | ||||
| : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
| "v17", "v18", "v19", "memory", "cc"); | |||||
| : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
| "v19", "memory", "cc"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -840,8 +842,9 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
| static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k, int n_remain) { | |||||
| static void kern_8x4( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k, int n_remain) { | |||||
| K /= 4; | K /= 4; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -1004,12 +1007,11 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), | [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), | ||||
| [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), | [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), | ||||
| [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), | [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), | ||||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||||
| [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "=r"(x0) | |||||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4), | |||||
| [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), | |||||
| [x0] "=r"(x0) | |||||
| : | : | ||||
| : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", | |||||
| "cc"); | |||||
| : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", "cc"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -1041,9 +1043,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| // | // | ||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k, int m_remain, | |||||
| int n_remain) { | |||||
| static void kern_4x4( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k, int m_remain, int n_remain) { | |||||
| K /= 4; | K /= 4; | ||||
| const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA); | const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA); | ||||
| const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB); | const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB); | ||||
| @@ -1172,10 +1174,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| "4:\n" STORE_C | "4:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), | ||||
| [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), | [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), | ||||
| [m_remain] "+r"(m_remain), [LDC] "+r"(LDC), | |||||
| [outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0), | |||||
| [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a), | |||||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), | |||||
| [m_remain] "+r"(m_remain), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
| [k] "+r"(k), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), | |||||
| [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), | |||||
| [outptr3] "=r"(outptr3), [x0] "=r"(x0), [x1] "=r"(x1) | [outptr3] "=r"(outptr3), [x0] "=r"(x0), [x1] "=r"(x1) | ||||
| : | : | ||||
| : "v4", "v5", "v6", "v7", "memory", "cc"); | : "v4", "v5", "v6", "v7", "memory", "cc"); | ||||
| @@ -1186,9 +1187,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| #undef STORE_C | #undef STORE_C | ||||
| } | } | ||||
| static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
| int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| static void gemm_s8_8x12_pack_A_n( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| @@ -1215,13 +1216,15 @@ static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
| int K = kmax - k0; | int K = kmax - k0; | ||||
| //! read 8 * 4 in each row | //! read 8 * 4 in each row | ||||
| for (; K > 15; K -= 16) { | for (; K > 15; K -= 16) { | ||||
| interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, outptr); | |||||
| interleave_8x4_4_b( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr); | |||||
| } | } | ||||
| if (K > 0) { | if (K > 0) { | ||||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
| inptr7, outptr, 4, K); | |||||
| interleave_8( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr, 4, K); | |||||
| } | } | ||||
| } | } | ||||
| for (; y < ymax; y += 4) { | for (; y < ymax; y += 4) { | ||||
| @@ -1274,8 +1277,8 @@ static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_s8_8x12_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, | |||||
| int x0, int xmax, int k0, int kmax) { | |||||
| static void gemm_s8_8x12_pack_A_t( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
| @@ -1361,8 +1364,8 @@ static void gemm_s8_8x12_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_s8_8x12_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| int x0, int xmax, int k0, int kmax) { | |||||
| static void gemm_s8_8x12_pack_B_n( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
| @@ -1448,9 +1451,9 @@ static void gemm_s8_8x12_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_s8_8x12_pack_B_t(dt_int8* outptr, const dt_int8* inptr, | |||||
| int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| static void gemm_s8_8x12_pack_B_t( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| @@ -1485,15 +1488,15 @@ static void gemm_s8_8x12_pack_B_t(dt_int8* outptr, const dt_int8* inptr, | |||||
| int K = kmax - k0; | int K = kmax - k0; | ||||
| //! read 12 * 4 in each row | //! read 12 * 4 in each row | ||||
| for (; K > 15; K -= 16) { | for (; K > 15; K -= 16) { | ||||
| interleave_12x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, inptr8, inptr9, inptr10, | |||||
| inptr11, outptr); | |||||
| interleave_12x4_4_b( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| inptr8, inptr9, inptr10, inptr11, outptr); | |||||
| } | } | ||||
| if (K > 0) { | if (K > 0) { | ||||
| interleave_12(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, | |||||
| outptr, 4, K); | |||||
| interleave_12( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| inptr8, inptr9, inptr10, inptr11, outptr, 4, K); | |||||
| } | } | ||||
| } | } | ||||
| for (; y < ymax; y += 4) { | for (; y < ymax; y += 4) { | ||||
| @@ -40,8 +40,9 @@ namespace matmul_mk4_8x12x4 { | |||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
| static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k) { | |||||
| static void kern_8x12( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k) { | |||||
| K /= 4; | K /= 4; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -397,8 +398,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
| static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k) { | |||||
| static void kern_4x12( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k) { | |||||
| K /= 4; | K /= 4; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -514,13 +516,12 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| "stp q16, q17, [%[outptr0], #128]\n" | "stp q16, q17, [%[outptr0], #128]\n" | ||||
| "stp q18, q19, [%[outptr0], #160]\n" | "stp q18, q19, [%[outptr0], #160]\n" | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), | ||||
| [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), | |||||
| [is_first_k] "+r"(is_first_k), [a0] "=w"(a0), [a0a] "=w"(a0a), | |||||
| [b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a), | |||||
| [b1a] "=w"(b1a), [b2a] "=w"(b2a) | |||||
| [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | |||||
| [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1), | |||||
| [b2] "=w"(b2), [b0a] "=w"(b0a), [b1a] "=w"(b1a), [b2a] "=w"(b2a) | |||||
| : | : | ||||
| : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||||
| "v17", "v18", "v19", "memory", "cc"); | |||||
| : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||||
| "v19", "memory", "cc"); | |||||
| } | } | ||||
| // Overview of register layout: | // Overview of register layout: | ||||
| @@ -544,8 +545,9 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
| static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k, int n_remain) { | |||||
| static void kern_8x4( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k, int n_remain) { | |||||
| K /= 4; | K /= 4; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -689,11 +691,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | ||||
| [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), | [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), | ||||
| [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), | [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), | ||||
| [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), | |||||
| [x0] "=r"(x0) | |||||
| [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [x0] "=r"(x0) | |||||
| : | : | ||||
| : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", | |||||
| "cc"); | |||||
| : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", "cc"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -720,8 +720,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| // Accumulator | // Accumulator | ||||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | MEGDNN_ATTRIBUTE_TARGET("dotprod") | ||||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| int32_t* output, int LDC, bool is_first_k, int n_remain) { | |||||
| static void kern_4x4( | |||||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||||
| bool is_first_k, int n_remain) { | |||||
| K /= 4; | K /= 4; | ||||
| const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA); | const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA); | ||||
| const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB); | const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB); | ||||
| @@ -834,10 +835,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| "4:\n" STORE_C | "4:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), | ||||
| [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), | |||||
| [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), [k] "+r"(k), | |||||
| [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a), | |||||
| [x0] "=r"(x0) | |||||
| [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), [LDC] "+r"(LDC), | |||||
| [outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0), [a0a] "=w"(a0a), | |||||
| [b0] "=w"(b0), [b0a] "=w"(b0a), [x0] "=r"(x0) | |||||
| : | : | ||||
| : "v4", "v5", "v6", "v7", "memory", "cc"); | : "v4", "v5", "v6", "v7", "memory", "cc"); | ||||
| @@ -847,13 +847,11 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| #undef STORE_C | #undef STORE_C | ||||
| } | } | ||||
| static void gemm_mk4_s8_8x12_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
| int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0, | |||||
| "mk4 matmul with m is not times of 4"); | |||||
| megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, | |||||
| "mk4 matmul with k is not times of 4"); | |||||
| static void gemm_mk4_s8_8x12_pack_A( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0, "mk4 matmul with m is not times of 4"); | |||||
| megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, "mk4 matmul with k is not times of 4"); | |||||
| int y = y0; | int y = y0; | ||||
| int start_y = y0 / 4; | int start_y = y0 / 4; | ||||
| for (; y + 7 < ymax; y += 8, start_y += 2) { | for (; y + 7 < ymax; y += 8, start_y += 2) { | ||||
| @@ -869,15 +867,15 @@ static void gemm_mk4_s8_8x12_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||||
| interleave_2x4_4_b(inptr0, inptr1, outptr); | interleave_2x4_4_b(inptr0, inptr1, outptr); | ||||
| } | } | ||||
| } | } | ||||
| for (; y + 3 < ymax; y += 4, start_y ++) { | |||||
| for (; y + 3 < ymax; y += 4, start_y++) { | |||||
| int K = kmax - k0; | int K = kmax - k0; | ||||
| const int8_t* inptr0 = inptr + start_y * ldin + (k0 << 2); | const int8_t* inptr0 = inptr + start_y * ldin + (k0 << 2); | ||||
| std::memcpy(outptr, inptr0, sizeof(dt_int8) * K * 4); | std::memcpy(outptr, inptr0, sizeof(dt_int8) * K * 4); | ||||
| } | } | ||||
| } | } | ||||
| static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
| int x0, int xmax, int k0, int kmax) { | |||||
| static void gemm_mk4_s8_8x12_pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
| const int ksize12 = ksize * 12; | const int ksize12 = ksize * 12; | ||||
| const int ksize4 = ksize * 4; | const int ksize4 = ksize * 4; | ||||
| @@ -12,10 +12,10 @@ | |||||
| #include "src/aarch64/matrix_mul/int8_dot/strategy.h" | #include "src/aarch64/matrix_mul/int8_dot/strategy.h" | ||||
| #if MGB_ENABLE_DOT | #if MGB_ENABLE_DOT | ||||
| #include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" | #include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" | ||||
| #include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h" | #include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h" | ||||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||||
| #include "src/common/utils.h" | |||||
| using namespace megdnn; | using namespace megdnn; | ||||
| using namespace aarch64; | using namespace aarch64; | ||||
| @@ -24,20 +24,19 @@ using namespace aarch64::matmul; | |||||
| /* ====================== gemm_s8_8x12 ===========================*/ | /* ====================== gemm_s8_8x12 ===========================*/ | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12); | ||||
| void gemm_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||||
| int y0, int ymax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| void gemm_s8_8x12::pack_A( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
| int xmax, int k0, int kmax, bool transpose) const { | |||||
| void gemm_s8_8x12::pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_8x12x4::gemm_s8_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | matmul_8x12x4::gemm_s8_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | ||||
| } else { | } else { | ||||
| @@ -45,16 +44,16 @@ void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
| } | } | ||||
| } | } | ||||
| void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| size_t N, size_t K, dt_int32* C, size_t LDC, | |||||
| bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
| C_dtype.name()); | |||||
| void gemm_s8_8x12::kern( | |||||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
| dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
| megdnn_assert( | |||||
| A_dtype.enumv() == B_dtype.enumv() && | |||||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
| MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
| MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
| @@ -75,15 +74,15 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, | |||||
| is_first_k); | |||||
| matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K12; | cur_packB += K12; | ||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, std::min<size_t>(N - n, 4)); | |||||
| matmul_8x12x4::kern_8x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| output += 4; | output += 4; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -95,16 +94,17 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| size_t n = 0; | size_t n = 0; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, std::min<size_t>(M - m, 4)); | |||||
| matmul_8x12x4::kern_4x12( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4)); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K12; | cur_packB += K12; | ||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, std::min<size_t>(M - m, 4), | |||||
| std::min<size_t>(N - n, 4)); | |||||
| matmul_8x12x4::kern_4x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
| output += 4; | output += 4; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -115,32 +115,32 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||||
| /* ====================== gemm_mk4_s8_8x12 ===========================*/ | /* ====================== gemm_mk4_s8_8x12 ===========================*/ | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_8x12); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_8x12); | ||||
| void gemm_mk4_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||||
| int y0, int ymax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix A is not supported"); | |||||
| matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| void gemm_mk4_s8_8x12::pack_A( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax, bool transpose) const { | |||||
| megdnn_assert( | |||||
| !transpose, "matrix mul mk4 with transposed matrix A is not supported"); | |||||
| matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| void gemm_mk4_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
| int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix B is not supported"); | |||||
| void gemm_mk4_s8_8x12::pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| megdnn_assert( | |||||
| !transpose, "matrix mul mk4 with transposed matrix B is not supported"); | |||||
| matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); | matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); | ||||
| } | } | ||||
| void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| size_t M, size_t N, size_t K, dt_int32* C, | |||||
| size_t LDC, bool is_first_k, const dt_int32*, | |||||
| dt_int32*) const { | |||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
| C_dtype.name()); | |||||
| void gemm_mk4_s8_8x12::kern( | |||||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
| dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||||
| megdnn_assert( | |||||
| A_dtype.enumv() == B_dtype.enumv() && | |||||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
| MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
| MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
| @@ -161,15 +161,15 @@ void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, | |||||
| is_first_k); | |||||
| matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k); | |||||
| output += (B_INTERLEAVE << 2); | output += (B_INTERLEAVE << 2); | ||||
| cur_packB += K12; | cur_packB += K12; | ||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_mk4_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, std::min<size_t>(N - n, 4)); | |||||
| matmul_mk4_8x12x4::kern_8x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| output += 16; | output += 16; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -181,15 +181,15 @@ void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| size_t n = 0; | size_t n = 0; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, | |||||
| is_first_k); | |||||
| matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k); | |||||
| output += (B_INTERLEAVE << 2); | output += (B_INTERLEAVE << 2); | ||||
| cur_packB += K12; | cur_packB += K12; | ||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_mk4_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, std::min<size_t>(N - n, 4)); | |||||
| matmul_mk4_8x12x4::kern_4x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| output += 16; | output += 16; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -16,14 +16,14 @@ namespace megdnn { | |||||
| namespace aarch64 { | namespace aarch64 { | ||||
| namespace matmul { | namespace matmul { | ||||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, | |||||
| gemm_s8_8x12); | |||||
| MEGDNN_REG_GEMM_STRATEGY( | |||||
| dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, gemm_s8_8x12); | |||||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, | |||||
| gemm_mk4_s8_8x12); | |||||
| MEGDNN_REG_GEMM_STRATEGY( | |||||
| dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, gemm_mk4_s8_8x12); | |||||
| } // namespace aarch64 | |||||
| } // namespace matmul | } // namespace matmul | ||||
| } // namespace aarch64 | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #endif | #endif | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -34,9 +34,9 @@ namespace matmul_4x4x16 { | |||||
| * | * | ||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| int16_t* output, int LDC, bool is_first_k, int m_remain, | |||||
| int n_remain) { | |||||
| static void kern_4x4( | |||||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
| bool is_first_k, int m_remain, int n_remain) { | |||||
| K /= 16; | K /= 16; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -230,16 +230,14 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| // Store back into memory | // Store back into memory | ||||
| STORE_C | STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [outptr] "+r"(outptr), [m_remain] "+r"(m_remain), | |||||
| [n_remain] "+r"(n_remain) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||||
| [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) | |||||
| : | : | ||||
| : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2", | |||||
| "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||||
| "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||||
| "v31"); | |||||
| : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2", "v3", | |||||
| "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", | |||||
| "v25", "v26", "v27", "v28", "v29", "v30", "v31"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -247,9 +245,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| #undef STORE_C | #undef STORE_C | ||||
| } | } | ||||
| static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
| int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| static void gemm_s8x8x16_4x4_pack_A_n( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| @@ -292,9 +290,11 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -309,9 +309,11 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -324,8 +326,8 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| int x0, int xmax, int k0, int kmax) { | |||||
| static void gemm_s8x8x16_4x4_pack_B_n( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
| @@ -362,19 +364,26 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| if (remain >= 0) { | if (remain >= 0) { | ||||
| switch (remain) { | switch (remain) { | ||||
| case 7: | case 7: | ||||
| inptr0 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr0 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -383,9 +392,9 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| } | } | ||||
| } | } | ||||
| transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, | |||||
| inptr4, inptr5, inptr6, inptr7, | |||||
| outptr_inner); | |||||
| transpose_4x16_1_b_helper( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr_inner); | |||||
| outptr_inner += ksize4; | outptr_inner += ksize4; | ||||
| } | } | ||||
| @@ -393,19 +402,26 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| if (remain >= 0) { | if (remain >= 0) { | ||||
| switch (remain) { | switch (remain) { | ||||
| case 7: | case 7: | ||||
| inptr0 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr0 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -42,8 +42,9 @@ namespace matmul_8x8x8 { | |||||
| * | * | ||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| int16_t* output, int LDC, bool is_first_k) { | |||||
| static void kern_8x8( | |||||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
| bool is_first_k) { | |||||
| K /= 8; | K /= 8; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -217,13 +218,12 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| "bne 2b\n" | "bne 2b\n" | ||||
| "3:\n" STORE_C | "3:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [outptr] "+r"(outptr) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "x1", "x2", "x3", | |||||
| "x4", "x5", "x6", "x7", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "v14", "v15", "v16", "v17", "x1", "x2", "x3", "x4", "x5", | |||||
| "x6", "x7", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| #undef STORE_LINE | #undef STORE_LINE | ||||
| @@ -258,9 +258,9 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| int16_t* output, int LDC, bool is_first_k, | |||||
| size_t n_remain) { | |||||
| static void kern_8x4( | |||||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
| bool is_first_k, size_t n_remain) { | |||||
| K /= 8; | K /= 8; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -471,16 +471,14 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| "cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
| "3:\n" STORE_C | "3:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||||
| [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), | |||||
| [n_remain] "+r"(n_remain) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||||
| [outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -514,9 +512,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| int16_t* output, int LDC, bool is_first_k, | |||||
| size_t m_remain) { | |||||
| static void kern_4x8( | |||||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
| bool is_first_k, size_t m_remain) { | |||||
| K /= 8; | K /= 8; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -646,11 +644,10 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| "cbnz %w[K], 2b\n" | "cbnz %w[K], 2b\n" | ||||
| "3:\n" STORE_C | "3:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), | |||||
| [m_remain] "+r"(m_remain) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||||
| [x0] "+r"(x0), [m_remain] "+r"(m_remain) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc", | : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc", | ||||
| "memory"); | "memory"); | ||||
| @@ -686,9 +683,9 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| * | * | ||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| int16_t* output, int LDC, bool is_first_k, size_t m_remain, | |||||
| size_t n_remain) { | |||||
| static void kern_4x4( | |||||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
| bool is_first_k, size_t m_remain, size_t n_remain) { | |||||
| K /= 8; | K /= 8; | ||||
| const int8_t* a_ptr = packA; | const int8_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -853,11 +850,10 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| "3:\n" STORE_C | "3:\n" STORE_C | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr), | : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr), | ||||
| [K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC), | [K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC), | ||||
| [x0] "+r"(x0), [m_remain] "+r"(m_remain), | |||||
| [n_remain] "+r"(n_remain) | |||||
| [x0] "+r"(x0), [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "x1", | |||||
| "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "x1", "cc", | |||||
| "memory"); | |||||
| #undef LOAD_LINE | #undef LOAD_LINE | ||||
| #undef LOAD_C | #undef LOAD_C | ||||
| @@ -865,9 +861,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| #undef STORE_C | #undef STORE_C | ||||
| } | } | ||||
| static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
| int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| static void gemm_s8x8x16_8x8_pack_A_n( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| @@ -893,13 +889,15 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
| int K = kmax - k0; | int K = kmax - k0; | ||||
| for (; K > 15; K -= 16) { | for (; K > 15; K -= 16) { | ||||
| interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, outptr); | |||||
| interleave_8x8_2_b( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr); | |||||
| } | } | ||||
| if (K > 0) { | if (K > 0) { | ||||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
| inptr7, outptr, 8, K); | |||||
| interleave_8( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr, 8, K); | |||||
| } | } | ||||
| } | } | ||||
| for (; y < ymax; y += 4) { | for (; y < ymax; y += 4) { | ||||
| @@ -918,9 +916,11 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -936,9 +936,11 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -951,9 +953,8 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||||
| int ldin, int x0, int xmax, | |||||
| int k0, int kmax) { | |||||
| static void gemm_s8x8x16_8x8_transpose_pack_A_n( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| @@ -991,17 +992,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||||
| if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
| switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1009,8 +1016,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||||
| megdnn_assert(0); | megdnn_assert(0); | ||||
| } | } | ||||
| } | } | ||||
| transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, outptr); | |||||
| transpose_8x8_1_b( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr); | |||||
| outptr += ksize8; | outptr += ksize8; | ||||
| } | } | ||||
| @@ -1019,17 +1027,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||||
| if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
| switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1038,8 +1052,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||||
| } | } | ||||
| } | } | ||||
| transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
| inptr7, outptr, 4, 4); | |||||
| transpose_8( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr, 4, 4); | |||||
| outptr += ksize4; | outptr += ksize4; | ||||
| } | } | ||||
| @@ -1047,17 +1062,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||||
| if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
| switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1066,8 +1087,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||||
| } | } | ||||
| } | } | ||||
| transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
| inptr7, outptr, 4, xmax - x); | |||||
| transpose_8( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr, 4, xmax - x); | |||||
| } | } | ||||
| outptr_base += 8 * 8; | outptr_base += 8 * 8; | ||||
| @@ -1075,8 +1097,8 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| int x0, int xmax, int k0, int kmax) { | |||||
| static void gemm_s8x8x16_8x8_pack_B_n( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||||
| int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| const int ksize = kmax - k0; | const int ksize = kmax - k0; | ||||
| @@ -1113,17 +1135,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
| switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1132,8 +1160,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| } | } | ||||
| } | } | ||||
| outptr_interleave = outptr; | outptr_interleave = outptr; | ||||
| interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, outptr_interleave); | |||||
| interleave_8x8_1_b( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr_interleave); | |||||
| outptr += ksize8; | outptr += ksize8; | ||||
| } | } | ||||
| @@ -1142,17 +1171,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
| switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1162,8 +1197,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| } | } | ||||
| outptr_interleave = outptr; | outptr_interleave = outptr; | ||||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
| inptr7, outptr_interleave, 4, 4); | |||||
| interleave_8( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr_interleave, 4, 4); | |||||
| outptr += ksize4; | outptr += ksize4; | ||||
| } | } | ||||
| @@ -1171,17 +1207,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| if (k + 7 >= kmax) { | if (k + 7 >= kmax) { | ||||
| switch (k + 7 - kmax) { | switch (k + 7 - kmax) { | ||||
| case 6: | case 6: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 5: | case 5: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 4: | case 4: | ||||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr3 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 3: | case 3: | ||||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr4 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 2: | case 2: | ||||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr5 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr6 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr7 = zerobuff; | inptr7 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1191,8 +1233,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| } | } | ||||
| outptr_interleave = outptr; | outptr_interleave = outptr; | ||||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
| inptr7, outptr_interleave, 4, xmax - x); | |||||
| interleave_8( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr_interleave, 4, xmax - x); | |||||
| } | } | ||||
| outptr_base += 8 * 8; | outptr_base += 8 * 8; | ||||
| @@ -1200,10 +1243,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, | |||||
| const dt_int8* inptr, int ldin, | |||||
| int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| static void gemm_s8x8x16_8x8_transpose_pack_B_n( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||||
| int kmax) { | |||||
| int8_t zerobuff[16]; | int8_t zerobuff[16]; | ||||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | std::memset(zerobuff, 0, sizeof(int8_t) * 16); | ||||
| constexpr int interleave4 = 32; | constexpr int interleave4 = 32; | ||||
| @@ -1231,14 +1273,16 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, | |||||
| int K = kmax - k0; | int K = kmax - k0; | ||||
| for (; K > 7; K -= 8) { | for (; K > 7; K -= 8) { | ||||
| transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||||
| inptr6, inptr7, outptr); | |||||
| transpose_8x8_1_b( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr); | |||||
| outptr += interleave8; | outptr += interleave8; | ||||
| } | } | ||||
| if (K > 0) { | if (K > 0) { | ||||
| transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||||
| inptr7, outptr, 8, K); | |||||
| transpose_8( | |||||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||||
| outptr, 8, K); | |||||
| outptr += interleave8; | outptr += interleave8; | ||||
| } | } | ||||
| } | } | ||||
| @@ -1259,9 +1303,11 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -1278,9 +1324,11 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, | |||||
| if (y + 3 >= ymax) { | if (y + 3 >= ymax) { | ||||
| switch (y + 3 - ymax) { | switch (y + 3 - ymax) { | ||||
| case 2: | case 2: | ||||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr1 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 1: | case 1: | ||||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||||
| inptr2 = zerobuff; | |||||
| MEGDNN_FALLTHRU | |||||
| case 0: | case 0: | ||||
| inptr3 = zerobuff; | inptr3 = zerobuff; | ||||
| break; | break; | ||||
| @@ -40,11 +40,9 @@ namespace matmul_mk4_16x12x4_a53 { | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| // clang-format on | // clang-format on | ||||
| static __attribute__((noinline)) void kern_16x12(const int16_t* packA, | |||||
| const int8_t* packB, int K, | |||||
| int16_t* output, int LDC, | |||||
| bool is_first_k, | |||||
| int remain_n) { | |||||
| static __attribute__((noinline)) void kern_16x12( | |||||
| const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
| bool is_first_k, int remain_n) { | |||||
| K /= 4; | K /= 4; | ||||
| const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -521,15 +519,15 @@ static __attribute__((noinline)) void kern_16x12(const int16_t* packA, | |||||
| "6:\n" STORE_C | "6:\n" STORE_C | ||||
| "101:\n" | "101:\n" | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), | |||||
| [remain_n] "+r"(remain_n) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
| "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||||
| "x8", "x9", "x10", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||||
| "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", | |||||
| "memory"); | |||||
| #undef STORE_C | #undef STORE_C | ||||
| #undef STORE_LINE | #undef STORE_LINE | ||||
| @@ -554,10 +552,9 @@ static __attribute__((noinline)) void kern_16x12(const int16_t* packA, | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| // clang-format on | // clang-format on | ||||
| static __attribute__((noinline)) void kern_8x12(const int16_t* packA, | |||||
| const int8_t* packB, int K, | |||||
| int16_t* output, int LDC, | |||||
| bool is_first_k, int remain_n) { | |||||
| static __attribute__((noinline)) void kern_8x12( | |||||
| const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
| bool is_first_k, int remain_n) { | |||||
| K /= 4; | K /= 4; | ||||
| const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -858,14 +855,13 @@ static __attribute__((noinline)) void kern_8x12(const int16_t* packA, | |||||
| "6:\n" STORE_C | "6:\n" STORE_C | ||||
| "101:\n" | "101:\n" | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), | |||||
| [remain_n] "+r"(remain_n) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
| "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", | |||||
| "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", "x2", "x3", | |||||
| "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", "memory"); | |||||
| #undef STORE_C | #undef STORE_C | ||||
| #undef STORE_LINE | #undef STORE_LINE | ||||
| @@ -890,10 +886,9 @@ static __attribute__((noinline)) void kern_8x12(const int16_t* packA, | |||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| // clang-format on | // clang-format on | ||||
| static __attribute__((noinline)) void kern_4x12(const int16_t* packA, | |||||
| const int8_t* packB, int K, | |||||
| int16_t* output, int LDC, | |||||
| bool is_first_k, int remain_n) { | |||||
| static __attribute__((noinline)) void kern_4x12( | |||||
| const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
| bool is_first_k, int remain_n) { | |||||
| K /= 4; | K /= 4; | ||||
| const int16_t* a_ptr = packA; | const int16_t* a_ptr = packA; | ||||
| const int8_t* b_ptr = packB; | const int8_t* b_ptr = packB; | ||||
| @@ -1162,22 +1157,21 @@ static __attribute__((noinline)) void kern_4x12(const int16_t* packA, | |||||
| "6:\n" STORE_C | "6:\n" STORE_C | ||||
| "101:\n" | "101:\n" | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||||
| [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||||
| [is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), | |||||
| [remain_n] "+r"(remain_n) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
| "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", | |||||
| "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", "x2", "x3", | |||||
| "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", "memory"); | |||||
| #undef STORE_C | #undef STORE_C | ||||
| #undef STORE_LINE | #undef STORE_LINE | ||||
| } | } | ||||
| static void gemm_s8x8x16_mk4_16x12_pack_A(dt_int16* outptr, | |||||
| const dt_int8* inptr, int ldin, | |||||
| int m0, int mmax, int k0, int kmax) { | |||||
| static void gemm_s8x8x16_mk4_16x12_pack_A( | |||||
| dt_int16* outptr, const dt_int8* inptr, int ldin, int m0, int mmax, int k0, | |||||
| int kmax) { | |||||
| megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); | megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); | ||||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | ||||
| constexpr int pack_m = 16; | constexpr int pack_m = 16; | ||||
| @@ -1224,9 +1218,8 @@ static void gemm_s8x8x16_mk4_16x12_pack_A(dt_int16* outptr, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_s8x8x16_mk4_16x12_pack_B(dt_int8* out, const dt_int8* in, | |||||
| int ldin, int n0, int nmax, int k0, | |||||
| int kmax) { | |||||
| static void gemm_s8x8x16_mk4_16x12_pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) { | |||||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | ||||
| constexpr int pack_n = 12; | constexpr int pack_n = 12; | ||||
| @@ -43,8 +43,9 @@ namespace matmul_mk4_4x4x8_a72 { | |||||
| */ | */ | ||||
| // clang-format on | // clang-format on | ||||
| static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| int16_t* output, int LDC, bool, int remain_n) { | |||||
| static inline void kern_4x4( | |||||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, bool, | |||||
| int remain_n) { | |||||
| K = div_ceil(K, 8); | K = div_ceil(K, 8); | ||||
| int oddk = (K & 1); | int oddk = (K & 1); | ||||
| K = ((K + 1) / 2) - 1; | K = ((K + 1) / 2) - 1; | ||||
| @@ -261,15 +262,14 @@ static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||||
| "7:\n" STORE_C | "7:\n" STORE_C | ||||
| "101:\n" | "101:\n" | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||||
| [oddk] "+r"(oddk), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||||
| [remain_n] "+r"(remain_n) | |||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [oddk] "+r"(oddk), | |||||
| [LDC] "+r"(LDC), [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||||
| : | : | ||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
| "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||||
| "x8", "x9", "x10", "cc", "memory"); | |||||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||||
| "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", | |||||
| "memory"); | |||||
| #undef STORE_C | #undef STORE_C | ||||
| #undef STORE_LINE | #undef STORE_LINE | ||||
| @@ -282,26 +282,23 @@ static inline void transpose_8x4_b(const dt_int8* inptr, dt_int8* outptr) { | |||||
| vst1_s8(outptr + 3 * 8, in0.val[3]); | vst1_s8(outptr + 3 * 8, in0.val[3]); | ||||
| } | } | ||||
| static inline void interleve_8x4_b(const dt_int8* inptr, const dt_int8* inptr2, | |||||
| dt_int8* outptr) { | |||||
| static inline void interleve_8x4_b( | |||||
| const dt_int8* inptr, const dt_int8* inptr2, dt_int8* outptr) { | |||||
| int8x16_t in0 = vld1q_s8(inptr); | int8x16_t in0 = vld1q_s8(inptr); | ||||
| int8x16_t in1 = vld1q_s8(inptr2); | int8x16_t in1 = vld1q_s8(inptr2); | ||||
| int32x4x2_t in_x2 = { | |||||
| {vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||||
| int32x4x2_t in_x2 = {{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||||
| vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); | vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); | ||||
| } | } | ||||
| static inline void interleve_8x4_b_pad(const dt_int8* inptr, dt_int8* outptr) { | static inline void interleve_8x4_b_pad(const dt_int8* inptr, dt_int8* outptr) { | ||||
| int8x16_t in0 = vld1q_s8(inptr); | int8x16_t in0 = vld1q_s8(inptr); | ||||
| int8x16_t in1 = vdupq_n_s8(0); | int8x16_t in1 = vdupq_n_s8(0); | ||||
| int32x4x2_t in_x2 = { | |||||
| {vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||||
| int32x4x2_t in_x2 = {{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||||
| vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); | vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); | ||||
| } | } | ||||
| static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in, | |||||
| int ldin, int m0, int mmax, int k0, | |||||
| int kmax) { | |||||
| static void gemm_s8x8x16_mk4_4x4x8_pack_A( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int m0, int mmax, int k0, int kmax) { | |||||
| megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); | megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); | ||||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | ||||
| constexpr int pack_m = 4; | constexpr int pack_m = 4; | ||||
| @@ -330,9 +327,8 @@ static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in, | |||||
| } | } | ||||
| } | } | ||||
| static void gemm_s8x8x16_mk4_4x4x8_pack_B(dt_int8* out, const dt_int8* in, | |||||
| int ldin, int n0, int nmax, int k0, | |||||
| int kmax) { | |||||
| static void gemm_s8x8x16_mk4_4x4x8_pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) { | |||||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | ||||
| constexpr int pack_n = 4; | constexpr int pack_n = 4; | ||||
| @@ -18,7 +18,6 @@ namespace megdnn { | |||||
| namespace aarch64 { | namespace aarch64 { | ||||
| namespace matmul_mk4_8x8x8 { | namespace matmul_mk4_8x8x8 { | ||||
| /** | /** | ||||
| * Overview of register layout: | * Overview of register layout: | ||||
| * | * | ||||
| @@ -39,18 +38,18 @@ namespace matmul_mk4_8x8x8 { | |||||
| * | v16 | | v28 | | * | v16 | | v28 | | ||||
| * | v17 | | v29 | | * | v17 | | v29 | | ||||
| * | v16 | | v30 | | * | v16 | | v30 | | ||||
| * | v17 | | v31 | | |||||
| * | v17 | | v31 | | |||||
| * +--------+ - - - - +---------------------------------+ | * +--------+ - - - - +---------------------------------+ | ||||
| * | * | ||||
| * Accumulator | * Accumulator | ||||
| */ | */ | ||||
| static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| int16_t* output, int LDC, bool is_first_k, int m_remain, | |||||
| int n_remain) { | |||||
| static void kern_8x8( | |||||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
| bool is_first_k, int m_remain, int n_remain) { | |||||
| K /= 8; | K /= 8; | ||||
| LDC = LDC * sizeof(int16_t); | LDC = LDC * sizeof(int16_t); | ||||
| const int8_t* a_ptr = packB;//packA; | |||||
| const int8_t* b_ptr = packA;//packB; | |||||
| const int8_t* a_ptr = packB; // packA; | |||||
| const int8_t* b_ptr = packA; // packB; | |||||
| // clang-format off | // clang-format off | ||||
| #define LOAD_C_8 \ | #define LOAD_C_8 \ | ||||
| "ld1 {v0.8h}, [x0], #16\n" \ | "ld1 {v0.8h}, [x0], #16\n" \ | ||||
| @@ -291,17 +290,17 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | ||||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | ||||
| "v29", "v30", "v31"); | "v29", "v30", "v31"); | ||||
| // clang-format on | |||||
| // clang-format on | |||||
| } | } | ||||
| static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
| int16_t* output, int LDC, bool is_first_k, int m_remain, | |||||
| int n_remain) { | |||||
| static void kern_8x8_remain( | |||||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
| bool is_first_k, int m_remain, int n_remain) { | |||||
| K /= 8; | K /= 8; | ||||
| LDC = LDC * sizeof(int16_t); | LDC = LDC * sizeof(int16_t); | ||||
| const int8_t* a_ptr = packB; | const int8_t* a_ptr = packB; | ||||
| const int8_t* b_ptr = packA; | const int8_t* b_ptr = packA; | ||||
| // clang-format off | |||||
| // clang-format off | |||||
| register int16_t* outptr asm("x0") = output; | register int16_t* outptr asm("x0") = output; | ||||
| asm volatile( | asm volatile( | ||||
| "add x1, x0, %x[LDC]\n" | "add x1, x0, %x[LDC]\n" | ||||
| @@ -476,7 +475,7 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
| "cbnz %w[K], 1b\n" | "cbnz %w[K], 1b\n" | ||||
| "cmp %w[is_first_k], #1\n" | "cmp %w[is_first_k], #1\n" | ||||
| "beq 2f\n" | |||||
| "beq 2f\n" | |||||
| "cmp %x[m_remain], #8 \n" | "cmp %x[m_remain], #8 \n" | ||||
| "beq 8f \n" | "beq 8f \n" | ||||
| "cmp %x[m_remain], #4 \n" | "cmp %x[m_remain], #4 \n" | ||||
| @@ -633,7 +632,7 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
| "zip2 v15.2d, v30.2d, v31.2d \n" | "zip2 v15.2d, v30.2d, v31.2d \n" | ||||
| "add v6.8h, v6.8h, v13.8h \n" | "add v6.8h, v6.8h, v13.8h \n" | ||||
| "add v7.8h, v7.8h, v15.8h \n" | "add v7.8h, v7.8h, v15.8h \n" | ||||
| //save to memory | |||||
| // save to memory | |||||
| "cmp %x[m_remain], #8 \n" | "cmp %x[m_remain], #8 \n" | ||||
| "beq 4f \n" | "beq 4f \n" | ||||
| "cmp %x[m_remain], #4 \n" | "cmp %x[m_remain], #4 \n" | ||||
| @@ -766,31 +765,27 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
| "b 1000f \n" | "b 1000f \n" | ||||
| "1000: \n" | "1000: \n" | ||||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||||
| [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) | |||||
| : | : | ||||
| [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), | |||||
| [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), | |||||
| [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), | |||||
| [ n_remain ] "+r"(n_remain) | |||||
| : | |||||
| : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", | |||||
| "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||||
| "v29", "v30", "v31"); | |||||
| // clang-format on | |||||
| : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0", | |||||
| "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||||
| "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||||
| "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); | |||||
| // clang-format on | |||||
| #undef LOAD_C_8 | #undef LOAD_C_8 | ||||
| #undef STORE_C_8 | #undef STORE_C_8 | ||||
| } | } | ||||
| static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| int16_t* output, int LDC, bool is_first_k, int m_remain, | |||||
| int n_remain) { | |||||
| static void kern_4x8( | |||||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
| bool is_first_k, int m_remain, int n_remain) { | |||||
| K /= 8; | K /= 8; | ||||
| LDC = LDC * sizeof(int16_t); | LDC = LDC * sizeof(int16_t); | ||||
| const int8_t* a_ptr = packB;//packA; | |||||
| const int8_t* b_ptr = packA;//packB; | |||||
| const int8_t* a_ptr = packB; // packA; | |||||
| const int8_t* b_ptr = packA; // packB; | |||||
| // clang-format off | // clang-format off | ||||
| #define LOAD_C_4 \ | #define LOAD_C_4 \ | ||||
| "ld1 {v0.8h}, [x0], #16\n" \ | "ld1 {v0.8h}, [x0], #16\n" \ | ||||
| @@ -1018,14 +1013,14 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||||
| #undef LOAD_C_4 | #undef LOAD_C_4 | ||||
| #undef STORE_C_4 | #undef STORE_C_4 | ||||
| } | } | ||||
| static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
| int16_t* output, int LDC, bool is_first_k, int m_remain, | |||||
| int n_remain) { | |||||
| static void kern_4x8_remain( | |||||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||||
| bool is_first_k, int m_remain, int n_remain) { | |||||
| K /= 8; | K /= 8; | ||||
| LDC = LDC * sizeof(int16_t); | LDC = LDC * sizeof(int16_t); | ||||
| const int8_t* a_ptr = packB;//packA; | |||||
| const int8_t* b_ptr = packA;//packB; | |||||
| // clang-format off | |||||
| const int8_t* a_ptr = packB; // packA; | |||||
| const int8_t* b_ptr = packA; // packB; | |||||
| // clang-format off | |||||
| register int16_t* outptr asm("x0") = output; | register int16_t* outptr asm("x0") = output; | ||||
| asm volatile( | asm volatile( | ||||
| @@ -1324,13 +1319,12 @@ static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||||
| #undef STORE_C_4 | #undef STORE_C_4 | ||||
| } | } | ||||
| //! pack to icxoc | //! pack to icxoc | ||||
| //! (M/4,K/4,4(K),4(M)) pack to (M/8,k/8,8(K_ic_0~3_ic_4~7),8(M_oc0~3_OC_4~7)) | //! (M/4,K/4,4(K),4(M)) pack to (M/8,k/8,8(K_ic_0~3_ic_4~7),8(M_oc0~3_OC_4~7)) | ||||
| //! if M K is not times of 8,pack 0 instead | |||||
| static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, | |||||
| const dt_int8* inptr, int ldin, | |||||
| int m0, int mmax, int k0, int kmax) { | |||||
| //! if M K is not times of 8,pack 0 instead | |||||
| static void gemm_s8x8x16_mk4_8x8x8_pack_A( | |||||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int m0, int mmax, int k0, | |||||
| int kmax) { | |||||
| megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); | megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); | ||||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | ||||
| constexpr int pack_m = 8; | constexpr int pack_m = 8; | ||||
| @@ -1349,8 +1343,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, | |||||
| prefetch_2x(inptr0); | prefetch_2x(inptr0); | ||||
| prefetch_2x(inptr1); | prefetch_2x(inptr1); | ||||
| int k_idx = k0; | int k_idx = k0; | ||||
| for ( ; k_idx + 7 < kmax; k_idx += pack_k) { | |||||
| interleave_8x8_mk4_b(inptr0,inptr1,outptr); | |||||
| for (; k_idx + 7 < kmax; k_idx += pack_k) { | |||||
| interleave_8x8_mk4_b(inptr0, inptr1, outptr); | |||||
| } | } | ||||
| if (k_idx < kmax) { | if (k_idx < kmax) { | ||||
| @@ -1368,9 +1362,9 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, | |||||
| prefetch_2x(inptr0); | prefetch_2x(inptr0); | ||||
| prefetch_2x(inptr1); | prefetch_2x(inptr1); | ||||
| int k_idx = k0; | int k_idx = k0; | ||||
| for ( ; k_idx + 7 < kmax; k_idx += pack_k) { | |||||
| for (; k_idx + 7 < kmax; k_idx += pack_k) { | |||||
| inptr1 = zerobuff; | inptr1 = zerobuff; | ||||
| interleave_8x8_mk4_b(inptr0,inptr1,outptr); | |||||
| interleave_8x8_mk4_b(inptr0, inptr1, outptr); | |||||
| } | } | ||||
| if (k_idx < kmax) { | if (k_idx < kmax) { | ||||
| @@ -1383,9 +1377,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, | |||||
| } | } | ||||
| //! pack to nxic | //! pack to nxic | ||||
| //! (K/4,N,4) pack to K/8,N,8(ic0~7) ,K is not times of 8 ,pack 0 instead. | //! (K/4,N,4) pack to K/8,N,8(ic0~7) ,K is not times of 8 ,pack 0 instead. | ||||
| static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||||
| int ldin, int n0, int nmax, int k0, | |||||
| int kmax) { | |||||
| static void gemm_s8x8x16_mk4_8x8x8_pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) { | |||||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | ||||
| constexpr int pack_n = 8; | constexpr int pack_n = 8; | ||||
| @@ -1394,14 +1387,14 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||||
| int8_t tmpbuff0[pack_n * pack_size] = {0}; | int8_t tmpbuff0[pack_n * pack_size] = {0}; | ||||
| int8_t tmpbuff1[pack_n * pack_size] = {0}; | int8_t tmpbuff1[pack_n * pack_size] = {0}; | ||||
| int8_t zerobuff[pack_n * pack_size] = {0}; | int8_t zerobuff[pack_n * pack_size] = {0}; | ||||
| const int ksize = round_up<int>((kmax - k0),8); | |||||
| const int ksize = round_up<int>((kmax - k0), 8); | |||||
| const int nsize = nmax - n0; | const int nsize = nmax - n0; | ||||
| const int n_end = nsize / pack_n * pack_n + n0; | const int n_end = nsize / pack_n * pack_n + n0; | ||||
| const int remain_n = nsize % pack_n; | const int remain_n = nsize % pack_n; | ||||
| int output_stride = ksize * pack_n; | int output_stride = ksize * pack_n; | ||||
| int8_t* outptr_base = out; | int8_t* outptr_base = out; | ||||
| int k_idx = k0; | int k_idx = k0; | ||||
| for ( ; k_idx + 7 < kmax; k_idx += pack_k) { | |||||
| for (; k_idx + 7 < kmax; k_idx += pack_k) { | |||||
| const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size; | const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size; | ||||
| const int8_t* inptr1 = inptr0 + ldin; | const int8_t* inptr1 = inptr0 + ldin; | ||||
| prefetch_3x(inptr0); | prefetch_3x(inptr0); | ||||
| @@ -1410,7 +1403,7 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||||
| auto outptr = outptr_base; | auto outptr = outptr_base; | ||||
| for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) { | for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) { | ||||
| transpose_8x8_mk4_b(inptr0, inptr1, outptr); | transpose_8x8_mk4_b(inptr0, inptr1, outptr); | ||||
| outptr += output_stride; | |||||
| outptr += output_stride; | |||||
| } | } | ||||
| if (remain_n > 0) { | if (remain_n > 0) { | ||||
| memcpy(tmpbuff0, inptr0, sizeof(int8_t) * remain_n * pack_size); | memcpy(tmpbuff0, inptr0, sizeof(int8_t) * remain_n * pack_size); | ||||
| @@ -1422,8 +1415,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||||
| } | } | ||||
| outptr_base += pack_n * pack_k; | outptr_base += pack_n * pack_k; | ||||
| } | } | ||||
| if(k_idx < kmax){ | |||||
| if (k_idx < kmax) { | |||||
| const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size; | const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size; | ||||
| const int8_t* inptr1 = nullptr; | const int8_t* inptr1 = nullptr; | ||||
| prefetch_3x(inptr0); | prefetch_3x(inptr0); | ||||
| @@ -1444,7 +1437,7 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||||
| } | } | ||||
| } | } | ||||
| } // namespace matmul_mk4_16x12x4_a53 | |||||
| } // namespace matmul_mk4_8x8x8 | |||||
| } // namespace aarch64 | } // namespace aarch64 | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -10,13 +10,13 @@ | |||||
| * implied. | * implied. | ||||
| */ | */ | ||||
| #include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | |||||
| #include "src/aarch64/matrix_mul/asm/common.h" | #include "src/aarch64/matrix_mul/asm/common.h" | ||||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" | #include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" | ||||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" | #include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" | ||||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h" | |||||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h" | #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h" | ||||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h" | #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h" | ||||
| #include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | |||||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h" | |||||
| #include "src/arm_common/simd_macro/marm_neon.h" | #include "src/arm_common/simd_macro/marm_neon.h" | ||||
| #include "src/common/utils.h" | #include "src/common/utils.h" | ||||
| #include "src/fallback/matrix_mul/gemm_common.h" | #include "src/fallback/matrix_mul/gemm_common.h" | ||||
| @@ -28,39 +28,35 @@ using namespace aarch64::matmul; | |||||
| // ===========================gemm_s8x8x16_4x4================================== | // ===========================gemm_s8x8x16_4x4================================== | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8); | ||||
| void gemm_s8x8x16_8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | |||||
| int ymax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| void gemm_s8x8x16_8x8::pack_A( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n(out, in, ldin, y0, | |||||
| ymax, k0, kmax); | |||||
| matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n( | |||||
| out, in, ldin, y0, ymax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void gemm_s8x8x16_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
| int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| void gemm_s8x8x16_8x8::pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n(out, in, ldin, x0, | |||||
| xmax, k0, kmax); | |||||
| matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n( | |||||
| out, in, ldin, x0, xmax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, | |||||
| kmax); | |||||
| matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| size_t M, size_t N, size_t K, dt_int16* C, | |||||
| size_t LDC, bool is_first_k, const dt_int16*, | |||||
| dt_int16*) const { | |||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
| (A_dtype.enumv() == DTypeEnum::Int8 && | |||||
| C_dtype.enumv() == DTypeEnum::Int16), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
| C_dtype.name()); | |||||
| void gemm_s8x8x16_8x8::kern( | |||||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
| dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||||
| megdnn_assert( | |||||
| A_dtype.enumv() == B_dtype.enumv() && (A_dtype.enumv() == DTypeEnum::Int8 && | |||||
| C_dtype.enumv() == DTypeEnum::Int16), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
| MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
| MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
| MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
| @@ -79,15 +75,15 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, | |||||
| is_first_k); | |||||
| matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K8; | cur_packB += K8; | ||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| matmul_8x8x8::kern_8x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(N - n, 4)); | |||||
| output += 4; | output += 4; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -99,16 +95,17 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| size_t n = 0; | size_t n = 0; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4)); | |||||
| matmul_8x8x8::kern_4x8( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4)); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K8; | cur_packB += K8; | ||||
| } | } | ||||
| for (; n < N; n += 4) { | for (; n < N; n += 4) { | ||||
| matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4), | |||||
| std::min<size_t>(N - n, 4)); | |||||
| matmul_8x8x8::kern_4x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||||
| output += 4; | output += 4; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -119,39 +116,33 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| // ===========================gemm_s8x8x16_4x4================================== | // ===========================gemm_s8x8x16_4x4================================== | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x4); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x4); | ||||
| void gemm_s8x8x16_4x4::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | |||||
| int ymax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| void gemm_s8x8x16_4x4::pack_A( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0, | |||||
| kmax); | |||||
| matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void gemm_s8x8x16_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||||
| int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| void gemm_s8x8x16_4x4::pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool transpose) const { | |||||
| if (transpose) { | if (transpose) { | ||||
| matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, | |||||
| kmax); | |||||
| matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | |||||
| } else { | } else { | ||||
| matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, | |||||
| kmax); | |||||
| matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||||
| } | } | ||||
| } | } | ||||
| void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| size_t M, size_t N, size_t K, dt_int16* C, | |||||
| size_t LDC, bool is_first_k, const dt_int16*, | |||||
| dt_int16*) const { | |||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
| (A_dtype.enumv() == DTypeEnum::Int8 && | |||||
| C_dtype.enumv() == DTypeEnum::Int16), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||||
| C_dtype.name()); | |||||
| void gemm_s8x8x16_4x4::kern( | |||||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
| dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||||
| megdnn_assert( | |||||
| A_dtype.enumv() == B_dtype.enumv() && (A_dtype.enumv() == DTypeEnum::Int8 && | |||||
| C_dtype.enumv() == DTypeEnum::Int16), | |||||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||||
| MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
| MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
| MEGDNN_MARK_USED_VAR(C_dtype); | MEGDNN_MARK_USED_VAR(C_dtype); | ||||
| @@ -169,16 +160,17 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | ||||
| matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, A_INTERLEAVE, B_INTERLEAVE); | |||||
| matmul_4x4x16::kern_4x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, | |||||
| B_INTERLEAVE); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| for (; n < N; n += B_INTERLEAVE) { | for (; n < N; n += B_INTERLEAVE) { | ||||
| matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, A_INTERLEAVE, | |||||
| std::min<size_t>(N - n, B_INTERLEAVE)); | |||||
| matmul_4x4x16::kern_4x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, | |||||
| std::min<size_t>(N - n, B_INTERLEAVE)); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -191,10 +183,10 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| size_t n = 0; | size_t n = 0; | ||||
| const dt_int8* cur_packB = packB; | const dt_int8* cur_packB = packB; | ||||
| for (; n < N; n += B_INTERLEAVE) { | for (; n < N; n += B_INTERLEAVE) { | ||||
| matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, | |||||
| std::min<size_t>(M - m, A_INTERLEAVE), | |||||
| std::min<size_t>(N - n, B_INTERLEAVE)); | |||||
| matmul_4x4x16::kern_4x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, | |||||
| std::min<size_t>(M - m, A_INTERLEAVE), | |||||
| std::min<size_t>(N - n, B_INTERLEAVE)); | |||||
| output += B_INTERLEAVE; | output += B_INTERLEAVE; | ||||
| cur_packB += K4; | cur_packB += K4; | ||||
| } | } | ||||
| @@ -205,28 +197,26 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| // ===========================gemm_s8x8x16_mk4_16x12================================== | // ===========================gemm_s8x8x16_mk4_16x12================================== | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_16x12_a53); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_16x12_a53); | ||||
| void gemm_s8x8x16_mk4_16x12_a53::pack_A(dt_int16* out, const dt_int8* in, | |||||
| int ldin, int y0, int ymax, int k0, | |||||
| int kmax, bool) const { | |||||
| matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A(out, in, ldin, y0, | |||||
| ymax, k0, kmax); | |||||
| void gemm_s8x8x16_mk4_16x12_a53::pack_A( | |||||
| dt_int16* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
| bool) const { | |||||
| matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A( | |||||
| out, in, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| void gemm_s8x8x16_mk4_16x12_a53::pack_B(dt_int8* out, const dt_int8* in, | |||||
| int ldin, int x0, int xmax, int k0, | |||||
| int kmax, bool) const { | |||||
| matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B(out, in, ldin, x0, | |||||
| xmax, k0, kmax); | |||||
| void gemm_s8x8x16_mk4_16x12_a53::pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool) const { | |||||
| matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B( | |||||
| out, in, ldin, x0, xmax, k0, kmax); | |||||
| } | } | ||||
| void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||||
| const dt_int8* packB, size_t M, size_t N, | |||||
| size_t K, dt_int16* C, size_t LDC, | |||||
| bool is_first_k, const dt_int16*, | |||||
| dt_int16*) const { | |||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
| C_dtype.enumv() == DTypeEnum::Int16 && | |||||
| A_dtype.enumv() == DTypeEnum::Int8); | |||||
| void gemm_s8x8x16_mk4_16x12_a53::kern( | |||||
| const dt_int16* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
| dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||||
| megdnn_assert( | |||||
| A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 && | |||||
| A_dtype.enumv() == DTypeEnum::Int8); | |||||
| megdnn_assert(is_first_k == true, "only impl is_first_k"); | megdnn_assert(is_first_k == true, "only impl is_first_k"); | ||||
| MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
| MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
| @@ -246,14 +236,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||||
| size_t n_idx = 0; | size_t n_idx = 0; | ||||
| const int8_t* cur_packB = packB; | const int8_t* cur_packB = packB; | ||||
| for (; n_idx + pack_n <= N; n_idx += pack_n) { | for (; n_idx + pack_n <= N; n_idx += pack_n) { | ||||
| matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, pack_n); | |||||
| matmul_mk4_16x12x4_a53::kern_16x12( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, pack_n); | |||||
| output += pack_n * pack_size; | output += pack_n * pack_size; | ||||
| cur_packB += pack_n * K; | cur_packB += pack_n * K; | ||||
| } | } | ||||
| if (remain_n > 0) { | if (remain_n > 0) { | ||||
| matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, remain_n); | |||||
| matmul_mk4_16x12x4_a53::kern_16x12( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, remain_n); | |||||
| output += remain_n * pack_size; | output += remain_n * pack_size; | ||||
| cur_packB += pack_n * K; | cur_packB += pack_n * K; | ||||
| } | } | ||||
| @@ -265,14 +255,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||||
| size_t n_idx = 0; | size_t n_idx = 0; | ||||
| const int8_t* cur_packB = packB; | const int8_t* cur_packB = packB; | ||||
| for (; n_idx + pack_n <= N; n_idx += pack_n) { | for (; n_idx + pack_n <= N; n_idx += pack_n) { | ||||
| matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, pack_n); | |||||
| matmul_mk4_16x12x4_a53::kern_8x12( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, pack_n); | |||||
| output += pack_n * pack_size; | output += pack_n * pack_size; | ||||
| cur_packB += pack_n * K; | cur_packB += pack_n * K; | ||||
| } | } | ||||
| if (remain_n > 0) { | if (remain_n > 0) { | ||||
| matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, remain_n); | |||||
| matmul_mk4_16x12x4_a53::kern_8x12( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, remain_n); | |||||
| output += remain_n * pack_size; | output += remain_n * pack_size; | ||||
| cur_packB += pack_n * K; | cur_packB += pack_n * K; | ||||
| } | } | ||||
| @@ -286,14 +276,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||||
| size_t n_idx = 0; | size_t n_idx = 0; | ||||
| const int8_t* cur_packB = packB; | const int8_t* cur_packB = packB; | ||||
| for (; n_idx + pack_n <= N; n_idx += pack_n) { | for (; n_idx + pack_n <= N; n_idx += pack_n) { | ||||
| matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, pack_n); | |||||
| matmul_mk4_16x12x4_a53::kern_4x12( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, pack_n); | |||||
| output += pack_n * pack_size; | output += pack_n * pack_size; | ||||
| cur_packB += pack_n * K; | cur_packB += pack_n * K; | ||||
| } | } | ||||
| if (remain_n > 0) { | if (remain_n > 0) { | ||||
| matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, remain_n); | |||||
| matmul_mk4_16x12x4_a53::kern_4x12( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, remain_n); | |||||
| output += remain_n * pack_size; | output += remain_n * pack_size; | ||||
| cur_packB += pack_n * K; | cur_packB += pack_n * K; | ||||
| } | } | ||||
| @@ -303,27 +293,26 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||||
| // ===========================gemm_s8x8x16_mk4_4x4_a72================================== | // ===========================gemm_s8x8x16_mk4_4x4_a72================================== | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_4x4_a72); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_4x4_a72); | ||||
| void gemm_s8x8x16_mk4_4x4_a72::pack_A(dt_int8* out, const dt_int8* in, int ldin, | |||||
| int y0, int ymax, int k0, int kmax, | |||||
| bool) const { | |||||
| matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A(out, in, ldin, y0, ymax, | |||||
| k0, kmax); | |||||
| void gemm_s8x8x16_mk4_4x4_a72::pack_A( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
| bool) const { | |||||
| matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A( | |||||
| out, in, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| void gemm_s8x8x16_mk4_4x4_a72::pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||||
| int x0, int xmax, int k0, int kmax, | |||||
| bool) const { | |||||
| matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B(out, in, ldin, x0, xmax, | |||||
| k0, kmax); | |||||
| void gemm_s8x8x16_mk4_4x4_a72::pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool) const { | |||||
| matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B( | |||||
| out, in, ldin, x0, xmax, k0, kmax); | |||||
| } | } | ||||
| void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| size_t M, size_t N, size_t K, dt_int16* C, | |||||
| size_t LDC, bool is_first_k, | |||||
| const dt_int16*, dt_int16*) const { | |||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
| C_dtype.enumv() == DTypeEnum::Int16 && | |||||
| A_dtype.enumv() == DTypeEnum::Int8); | |||||
| void gemm_s8x8x16_mk4_4x4_a72::kern( | |||||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
| dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||||
| megdnn_assert( | |||||
| A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 && | |||||
| A_dtype.enumv() == DTypeEnum::Int8); | |||||
| megdnn_assert(is_first_k == true, "only impl is_first_k"); | megdnn_assert(is_first_k == true, "only impl is_first_k"); | ||||
| MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
| MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
| @@ -343,14 +332,14 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| const int8_t* cur_packB = packB; | const int8_t* cur_packB = packB; | ||||
| for (size_t n_idx = 0; n_idx < nend; n_idx += pack_n) { | for (size_t n_idx = 0; n_idx < nend; n_idx += pack_n) { | ||||
| matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, pack_n); | |||||
| matmul_mk4_4x4x8_a72::kern_4x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, pack_n); | |||||
| output += pack_n * pack_size; | output += pack_n * pack_size; | ||||
| cur_packB += pack_n * packed_k; | cur_packB += pack_n * packed_k; | ||||
| } | } | ||||
| if (remain_n > 0) { | if (remain_n > 0) { | ||||
| matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, remain_n); | |||||
| matmul_mk4_4x4x8_a72::kern_4x4( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, remain_n); | |||||
| output += remain_n * pack_size; | output += remain_n * pack_size; | ||||
| cur_packB += pack_n * packed_k; | cur_packB += pack_n * packed_k; | ||||
| } | } | ||||
| @@ -361,27 +350,24 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| // ===========================gemm_s8x8x16_mk4_8x8x8================================== | // ===========================gemm_s8x8x16_mk4_8x8x8================================== | ||||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8x8); | MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8x8); | ||||
| void gemm_s8x8x16_mk4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, | |||||
| int ldin, int y0, int ymax, int k0, | |||||
| int kmax, bool) const { | |||||
| matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0, | |||||
| ymax, k0, kmax); | |||||
| void gemm_s8x8x16_mk4_8x8x8::pack_A( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||||
| bool) const { | |||||
| matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0, ymax, k0, kmax); | |||||
| } | } | ||||
| void gemm_s8x8x16_mk4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, | |||||
| int ldin, int x0, int xmax, int k0, | |||||
| int kmax, bool) const { | |||||
| matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0, | |||||
| xmax, k0, kmax); | |||||
| void gemm_s8x8x16_mk4_8x8x8::pack_B( | |||||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||||
| bool) const { | |||||
| matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0, xmax, k0, kmax); | |||||
| } | } | ||||
| void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| size_t M, size_t N, size_t K, dt_int16* C, | |||||
| size_t LDC, bool is_first_k, const dt_int16*, | |||||
| dt_int16*) const { | |||||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||||
| C_dtype.enumv() == DTypeEnum::Int16 && | |||||
| A_dtype.enumv() == DTypeEnum::Int8); | |||||
| void gemm_s8x8x16_mk4_8x8x8::kern( | |||||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||||
| dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||||
| megdnn_assert( | |||||
| A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 && | |||||
| A_dtype.enumv() == DTypeEnum::Int8); | |||||
| megdnn_assert(is_first_k == true, "only impl is_first_k"); | megdnn_assert(is_first_k == true, "only impl is_first_k"); | ||||
| MEGDNN_MARK_USED_VAR(A_dtype); | MEGDNN_MARK_USED_VAR(A_dtype); | ||||
| MEGDNN_MARK_USED_VAR(B_dtype); | MEGDNN_MARK_USED_VAR(B_dtype); | ||||
| @@ -402,14 +388,14 @@ void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| size_t n_idx = 0; | size_t n_idx = 0; | ||||
| const int8_t* cur_packB = packB; | const int8_t* cur_packB = packB; | ||||
| for (; n_idx + pack_n <= N; n_idx += pack_n) { | for (; n_idx + pack_n <= N; n_idx += pack_n) { | ||||
| matmul_mk4_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, pack_m, pack_n); | |||||
| matmul_mk4_8x8x8::kern_8x8( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, pack_m, pack_n); | |||||
| output += pack_n * pack_size; | output += pack_n * pack_size; | ||||
| cur_packB += KSIZE8; | cur_packB += KSIZE8; | ||||
| } | } | ||||
| if (remain_n > 0) { | if (remain_n > 0) { | ||||
| matmul_mk4_8x8x8::kern_8x8_remain(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, pack_m, remain_n); | |||||
| matmul_mk4_8x8x8::kern_8x8_remain( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, pack_m, remain_n); | |||||
| output += remain_n * pack_size; | output += remain_n * pack_size; | ||||
| cur_packB += KSIZE8; | cur_packB += KSIZE8; | ||||
| } | } | ||||
| @@ -421,14 +407,14 @@ void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||||
| size_t n_idx = 0; | size_t n_idx = 0; | ||||
| const int8_t* cur_packB = packB; | const int8_t* cur_packB = packB; | ||||
| for (; n_idx + pack_n <= N; n_idx += pack_n) { | for (; n_idx + pack_n <= N; n_idx += pack_n) { | ||||
| matmul_mk4_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, 4, pack_n); | |||||
| matmul_mk4_8x8x8::kern_4x8( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, 4, pack_n); | |||||
| output += pack_n * pack_size; | output += pack_n * pack_size; | ||||
| cur_packB += pack_n * K; | cur_packB += pack_n * K; | ||||
| } | } | ||||
| if (remain_n > 0) { | if (remain_n > 0) { | ||||
| matmul_mk4_8x8x8::kern_4x8_remain(packA, cur_packB, K, output, LDC, | |||||
| is_first_k, 4, remain_n); | |||||
| matmul_mk4_8x8x8::kern_4x8_remain( | |||||
| packA, cur_packB, K, output, LDC, is_first_k, 4, remain_n); | |||||
| output += remain_n * pack_size; | output += remain_n * pack_size; | ||||
| cur_packB += pack_n * K; | cur_packB += pack_n * K; | ||||
| } | } | ||||