| @@ -23,9 +23,9 @@ | |||
| #pragma GCC diagnostic ignored "-Wunused-parameter" | |||
| #pragma GCC diagnostic ignored "-Wdeprecated-declarations" | |||
| #pragma GCC diagnostic ignored "-Wsign-compare" | |||
| #include <hip/hip_runtime_api.h> | |||
| #include <hip/hip_runtime.h> | |||
| #include <hip/hip_fp16.h> | |||
| #include <hip/hip_runtime.h> | |||
| #include <hip/hip_runtime_api.h> | |||
| #pragma GCC diagnostic pop | |||
| #if !defined(__HIP_PLATFORM_HCC__) | |||
| @@ -11,10 +11,10 @@ | |||
| #pragma once | |||
| #include "megdnn/thin/function.h" | |||
| #include "megcore_cdefs.h" | |||
| #include <cstddef> | |||
| #include <memory> | |||
| #include "megcore_cdefs.h" | |||
| #include "megdnn/thin/function.h" | |||
| #include "megdnn/internal/visibility_prologue.h" | |||
| @@ -26,36 +26,35 @@ namespace megcore { | |||
| * the caller thread immediately. | |||
| */ | |||
| class CPUDispatcher { | |||
| public: | |||
| using Task = megdnn::thin_function<void()>; | |||
| using MultiThreadingTask = megdnn::thin_function<void(size_t, size_t)>; | |||
| virtual ~CPUDispatcher() noexcept; | |||
| /*! | |||
| * \brief dispatch a task on the computing thread | |||
| * \param task the task to run; it is taken by rvalue reference and may be | |||
| *      moved from, so the caller must not rely on its value afterwards | |||
| */ | |||
| virtual void dispatch(Task&& task) = 0; | |||
| /*! | |||
| * \brief dispatch a multithreading task on the computing thread | |||
| * \param task the task to run; it is taken by rvalue reference and may be | |||
| *      moved from. The callable receives two size_t arguments. | |||
| *      NOTE(review): their meaning is not visible here (presumably a task | |||
| *      index and a thread id); confirm against the implementations. | |||
| * \param parallelism the parallelism of the task. | |||
| */ | |||
| virtual void dispatch(MultiThreadingTask&& task, | |||
| size_t parallelism) = 0; | |||
| /*! | |||
| * \brief synchronize the calling thread with the computing thread | |||
| */ | |||
| virtual void sync() = 0; | |||
| /*! | |||
| * \brief get the number of computing threads. | |||
| */ | |||
| virtual size_t nr_threads() = 0; | |||
| public: | |||
| using Task = megdnn::thin_function<void()>; | |||
| using MultiThreadingTask = megdnn::thin_function<void(size_t, size_t)>; | |||
| virtual ~CPUDispatcher() noexcept; | |||
| /*! | |||
| * \brief dispatch a task on the computing thread | |||
| * \param task the task to run; it is taken by rvalue reference and may be | |||
| *      moved from, so the caller must not rely on its value afterwards | |||
| */ | |||
| virtual void dispatch(Task&& task) = 0; | |||
| /*! | |||
| * \brief dispatch a multithreading task on the computing thread | |||
| * \param task the task to run; it is taken by rvalue reference and may be | |||
| *      moved from. The callable receives two size_t arguments. | |||
| *      NOTE(review): their meaning is not visible here (presumably a task | |||
| *      index and a thread id); confirm against the implementations. | |||
| * \param parallelism the parallelism of the task. | |||
| */ | |||
| virtual void dispatch(MultiThreadingTask&& task, size_t parallelism) = 0; | |||
| /*! | |||
| * \brief synchronize the calling thread with the computing thread | |||
| */ | |||
| virtual void sync() = 0; | |||
| /*! | |||
| * \brief get the number of computing threads. | |||
| */ | |||
| virtual size_t nr_threads() = 0; | |||
| }; | |||
| } // namespace megcore | |||
| } // namespace megcore | |||
| using MegcoreCPUDispatcher = megcore::CPUDispatcher; | |||
| @@ -63,75 +62,62 @@ using MegcoreCPUDispatcher = megcore::CPUDispatcher; | |||
| * \brief Layer 1: device handle | |||
| */ | |||
| struct megcoreDeviceContext; | |||
| typedef struct megcoreDeviceContext *megcoreDeviceHandle_t; | |||
| typedef struct megcoreDeviceContext* megcoreDeviceHandle_t; | |||
| megcoreStatus_t megcoreCreateDeviceHandle( | |||
| megcoreDeviceHandle_t *handle, | |||
| megcorePlatform_t platform, | |||
| int deviceID = -1, | |||
| megcoreDeviceHandle_t* handle, megcorePlatform_t platform, int deviceID = -1, | |||
| unsigned int flags = 0); | |||
| megcoreStatus_t megcoreDestroyDeviceHandle( | |||
| megcoreDeviceHandle_t handle); | |||
| megcoreStatus_t megcoreGetPlatform(megcoreDeviceHandle_t handle, | |||
| megcorePlatform_t *platform); | |||
| megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle, | |||
| int *deviceID); | |||
| megcoreStatus_t megcoreGetMemAlignment(megcoreDeviceHandle_t handle, | |||
| size_t *memAlignmentInBytes); | |||
| megcoreStatus_t megcoreDestroyDeviceHandle(megcoreDeviceHandle_t handle); | |||
| megcoreStatus_t megcoreGetPlatform( | |||
| megcoreDeviceHandle_t handle, megcorePlatform_t* platform); | |||
| megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle, int* deviceID); | |||
| megcoreStatus_t megcoreGetMemAlignment( | |||
| megcoreDeviceHandle_t handle, size_t* memAlignmentInBytes); | |||
| megcoreStatus_t megcoreGetDeviceFlags( | |||
| megcoreDeviceHandle_t handle, | |||
| unsigned int *flags); | |||
| megcoreDeviceHandle_t handle, unsigned int* flags); | |||
| megcoreStatus_t megcoreActivate(megcoreDeviceHandle_t handle); | |||
| megcoreStatus_t megcoreDeactivate(megcoreDeviceHandle_t handle); | |||
| megcoreStatus_t megcoreMalloc(megcoreDeviceHandle_t handle, | |||
| void **devPtr, size_t sizeInBytes); | |||
| megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle, | |||
| void *devPtr); | |||
| megcoreStatus_t megcoreMalloc( | |||
| megcoreDeviceHandle_t handle, void** devPtr, size_t sizeInBytes); | |||
| megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle, void* devPtr); | |||
| /** | |||
| * \brief Layer 2: computing handle | |||
| */ | |||
| struct megcoreComputingContext; | |||
| typedef struct megcoreComputingContext *megcoreComputingHandle_t; | |||
| typedef struct megcoreComputingContext* megcoreComputingHandle_t; | |||
| megcoreStatus_t megcoreCreateComputingHandle( | |||
| megcoreComputingHandle_t *compHandle, | |||
| megcoreDeviceHandle_t devHandle, | |||
| megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | |||
| unsigned int flags = 0); | |||
| megcoreStatus_t megcoreCreateComputingHandleWithCPUDispatcher( | |||
| megcoreComputingHandle_t *compHandle, | |||
| megcoreDeviceHandle_t devHandle, | |||
| megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | |||
| const std::shared_ptr<MegcoreCPUDispatcher>& dispatcher, | |||
| unsigned int flags = 0); | |||
| megcoreStatus_t megcoreDestroyComputingHandle( | |||
| megcoreComputingHandle_t handle); | |||
| megcoreStatus_t megcoreDestroyComputingHandle(megcoreComputingHandle_t handle); | |||
| megcoreStatus_t megcoreGetDeviceHandle( | |||
| megcoreComputingHandle_t compHandle, | |||
| megcoreDeviceHandle_t *devHandle); | |||
| megcoreComputingHandle_t compHandle, megcoreDeviceHandle_t* devHandle); | |||
| megcoreStatus_t megcoreGetComputingFlags( | |||
| megcoreComputingHandle_t handle, | |||
| unsigned int *flags); | |||
| megcoreComputingHandle_t handle, unsigned int* flags); | |||
| MegcoreCPUDispatcher* megcoreGetCPUDispatcher(megcoreComputingHandle_t handle); | |||
| megcoreStatus_t megcoreMemcpy( | |||
| megcoreComputingHandle_t handle, | |||
| void *dst, const void *src, size_t sizeInBytes, | |||
| megcoreComputingHandle_t handle, void* dst, const void* src, size_t sizeInBytes, | |||
| megcoreMemcpyKind_t kind); | |||
| megcoreStatus_t megcoreMemset( | |||
| megcoreComputingHandle_t handle, | |||
| void *dst, int value, size_t sizeInBytes); | |||
| megcoreComputingHandle_t handle, void* dst, int value, size_t sizeInBytes); | |||
| megcoreStatus_t megcoreSynchronize(megcoreComputingHandle_t handle); | |||
| /** | |||
| * \brief Miscellaneous | |||
| */ | |||
| const char *megcoreGetErrorName(megcoreStatus_t status); | |||
| const char* megcoreGetErrorName(megcoreStatus_t status); | |||
| #include "megdnn/internal/visibility_epilogue.h" | |||
| @@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithAtlasContext( | |||
| megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | |||
| unsigned int flags, const AtlasContext& ctx); | |||
| megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle, | |||
| AtlasContext* ctx); | |||
| megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle, AtlasContext* ctx); | |||
| namespace atlas { | |||
| //! convert acl error code to error string | |||
| @@ -47,12 +46,12 @@ inline megcoreStatus_t megcoreCreateComputingHandleWithACLStream( | |||
| megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | |||
| unsigned int flags, aclrtStream stream) { | |||
| megcore::AtlasContext ctx{stream}; | |||
| return megcore::createComputingHandleWithAtlasContext(compHandle, devHandle, | |||
| flags, ctx); | |||
| return megcore::createComputingHandleWithAtlasContext( | |||
| compHandle, devHandle, flags, ctx); | |||
| } | |||
| inline megcoreStatus_t megcoreGetACLStream(megcoreComputingHandle_t handle, | |||
| aclrtStream* stream) { | |||
| inline megcoreStatus_t megcoreGetACLStream( | |||
| megcoreComputingHandle_t handle, aclrtStream* stream) { | |||
| megcore::AtlasContext ctx; | |||
| auto ret = megcore::getAtlasContext(handle, &ctx); | |||
| *stream = ctx.stream; | |||
| @@ -34,8 +34,8 @@ megcoreStatus_t createComputingHandleWithCambriconContext( | |||
| megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | |||
| unsigned int flags, const CambriconContext& ctx); | |||
| megcoreStatus_t getCambriconContext(megcoreComputingHandle_t handle, | |||
| CambriconContext* ctx); | |||
| megcoreStatus_t getCambriconContext( | |||
| megcoreComputingHandle_t handle, CambriconContext* ctx); | |||
| } // namespace megcore | |||
| @@ -58,4 +58,3 @@ static inline megcoreStatus_t megcoreGetCNRTQueue( | |||
| #include "megdnn/internal/visibility_epilogue.h" | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -40,7 +40,6 @@ typedef enum { | |||
| megcoreErrorInternalError = 5, | |||
| } megcoreStatus_t; | |||
| /** | |||
| * \brief Memcpy kind | |||
| */ | |||
| @@ -70,6 +69,6 @@ struct AsyncErrorInfo { | |||
| char msg[228]; | |||
| int msg_args[4]; | |||
| }; | |||
| } // namespace megcore | |||
| } // namespace megcore | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -33,8 +33,7 @@ megcoreStatus_t createComputingHandleWithCUDAContext( | |||
| megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | |||
| unsigned int flags, const CudaContext& ctx); | |||
| megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle, | |||
| CudaContext* ctx); | |||
| megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle, CudaContext* ctx); | |||
| } // namespace megcore | |||
| @@ -43,8 +42,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithCUDAStream( | |||
| unsigned int flags, cudaStream_t stream) { | |||
| megcore::CudaContext ctx; | |||
| ctx.stream = stream; | |||
| return megcore::createComputingHandleWithCUDAContext(compHandle, devHandle, | |||
| flags, ctx); | |||
| return megcore::createComputingHandleWithCUDAContext( | |||
| compHandle, devHandle, flags, ctx); | |||
| } | |||
| static inline megcoreStatus_t megcoreGetCUDAStream( | |||
| @@ -23,7 +23,9 @@ struct ROCMContext { | |||
| hipStream_t stream = nullptr; | |||
| static std::atomic_bool sm_miopen_algo_search; | |||
| static inline bool enable_miopen_algo_search() { return sm_miopen_algo_search.load(); } | |||
| static inline bool enable_miopen_algo_search() { | |||
| return sm_miopen_algo_search.load(); | |||
| } | |||
| static inline void enable_miopen_algo_search(bool enable_algo_search) { | |||
| sm_miopen_algo_search.store(enable_algo_search); | |||
| } | |||
| @@ -40,8 +42,7 @@ megcoreStatus_t createComputingHandleWithROCMContext( | |||
| megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | |||
| unsigned int flags, const ROCMContext& ctx); | |||
| megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle, | |||
| ROCMContext* ctx); | |||
| megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle, ROCMContext* ctx); | |||
| // Set MIOpen algo search enabled or disabled | |||
| megcoreStatus_t enableMIOpenAlgoSearch(bool enable_algo_search = true); | |||
| @@ -55,8 +56,8 @@ static inline megcoreStatus_t megcoreCreateComputingHandleWithROCMStream( | |||
| unsigned int flags, hipStream_t stream) { | |||
| megcore::ROCMContext ctx; | |||
| ctx.stream = stream; | |||
| return megcore::createComputingHandleWithROCMContext(compHandle, devHandle, | |||
| flags, ctx); | |||
| return megcore::createComputingHandleWithROCMContext( | |||
| compHandle, devHandle, flags, ctx); | |||
| } | |||
| static inline megcoreStatus_t megcoreGetROCMStream( | |||
| @@ -10,7 +10,7 @@ | |||
| */ | |||
| #pragma once | |||
| #include "megdnn/version.h" | |||
| #include "megdnn/oprs.h" | |||
| #include "megdnn/version.h" | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -14,20 +14,20 @@ | |||
| #include "megdnn/config/config.h" | |||
| #if defined(__GNUC__) || defined(__clang__) | |||
| #if !defined (__clang__) | |||
| // gcc specific | |||
| #define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) | |||
| #if GCC_VERSION < 40800 | |||
| #error "GCC version should be at least 4.8.0." | |||
| #endif // GCC_VERSION < 40800 | |||
| #endif // !defined(__clang__) | |||
| #if !defined(__clang__) | |||
| // gcc specific | |||
| #define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) | |||
| #if GCC_VERSION < 40800 | |||
| #error "GCC version should be at least 4.8.0." | |||
| #endif // GCC_VERSION < 40800 | |||
| #endif // !defined(__clang__) | |||
| #ifndef megdnn_trap | |||
| #define megdnn_trap() __builtin_trap() | |||
| #endif | |||
| #ifndef megdnn_trap | |||
| #define megdnn_trap() __builtin_trap() | |||
| #endif | |||
| #define megdnn_likely(v) __builtin_expect(bool(v), 1) | |||
| #define megdnn_unlikely(v) __builtin_expect(bool(v), 0) | |||
| #define megdnn_likely(v) __builtin_expect(bool(v), 1) | |||
| #define megdnn_unlikely(v) __builtin_expect(bool(v), 0) | |||
| #if !defined(__clang__) && MEGDNN_ARMV7 && !defined(NDEBUG) | |||
| //! Thumb2 limit code length | |||
| @@ -36,123 +36,122 @@ | |||
| #define MEGDNN_ALWAYS_INLINE inline __attribute__((__always_inline__)) | |||
| #endif | |||
| #define MEGDNN_DEPRECATED __attribute__((deprecated)) | |||
| #define MEGDNN_PACKED __attribute__((packed)) | |||
| #define MEGDNN_CONSTEXPR constexpr | |||
| #define MEGDNN_NOEXCEPT noexcept | |||
| #define MEGDNN_STATIC_ASSERT static_assert | |||
| #define MEGDNN_FINAL final | |||
| #define MEGDNN_NORETURN __attribute__((noreturn)) | |||
| #define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result)) | |||
| #define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) | |||
| #if defined(__clang_major__) && (__clang_major__ >= 7) | |||
| #define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) | |||
| #else | |||
| #define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]] | |||
| #endif | |||
| #define MEGDNN_NOINLINE __attribute__((noinline)) | |||
| #define megdnn_isatty(x) isatty(x) | |||
| #define MEGDNN_DEPRECATED __attribute__((deprecated)) | |||
| #define MEGDNN_PACKED __attribute__((packed)) | |||
| #define MEGDNN_CONSTEXPR constexpr | |||
| #define MEGDNN_NOEXCEPT noexcept | |||
| #define MEGDNN_STATIC_ASSERT static_assert | |||
| #define MEGDNN_FINAL final | |||
| #define MEGDNN_NORETURN __attribute__((noreturn)) | |||
| #define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result)) | |||
| #define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) | |||
| #if defined(__clang_major__) && (__clang_major__ >= 7) | |||
| #define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) | |||
| #else | |||
| #define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]] | |||
| #endif | |||
| #define MEGDNN_NOINLINE __attribute__((noinline)) | |||
| #define megdnn_isatty(x) isatty(x) | |||
| #elif defined(__INTEL_COMPILER) || defined(_MSC_VER) | |||
| #ifndef megdnn_trap | |||
| #define megdnn_trap() __debugbreak() | |||
| #endif | |||
| #define megdnn_likely(v) (bool(v)) | |||
| #define megdnn_likely(v) (bool(v)) | |||
| #define megdnn_unlikely(v) (bool(v)) | |||
| #define MEGDNN_DEPRECATED | |||
| #define MEGDNN_PACKED | |||
| #define MEGDNN_CONSTEXPR constexpr | |||
| #define MEGDNN_NOEXCEPT noexcept | |||
| #define MEGDNN_CONSTEXPR constexpr | |||
| #define MEGDNN_NOEXCEPT noexcept | |||
| #define MEGDNN_STATIC_ASSERT static_assert | |||
| #define MEGDNN_FINAL final | |||
| #define MEGDNN_FINAL final | |||
| #if defined(_MSC_VER) | |||
| #define MEGDNN_NORETURN __declspec(noreturn) | |||
| #define MEGDNN_NOINLINE __declspec(noinline) | |||
| #define MEGDNN_NORETURN __declspec(noreturn) | |||
| #define MEGDNN_NOINLINE __declspec(noinline) | |||
| #else | |||
| #define MEGDNN_NORETURN | |||
| #define MEGDNN_FORCE_NOINLINE | |||
| #endif // _MSC_VER | |||
| #define MEGDNN_NORETURN | |||
| #define MEGDNN_FORCE_NOINLINE | |||
| #endif // _MSC_VER | |||
| #define MEGDNN_WARN_UNUSED_RESULT | |||
| #define megdnn_isatty(x) _isatty(x) | |||
| #else | |||
| #error "unknown compiler" | |||
| #endif // __GNUC__ | |||
| #error "unknown compiler" | |||
| #endif // __GNUC__ | |||
| // __cpp_exceptions and __cpp_rtti are taken from | |||
| // https://isocpp.org/std/standing-documents/sd-6-sg10-feature-test-recommendations | |||
| // gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS, | |||
| // gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS, | |||
| // similar for __GXX_RTTI | |||
| // _CPPUNWIND and _CPPRTTI are used by MSVC; see | |||
| // https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macros?view=vs-2019 | |||
| #ifndef MEGDNN_ENABLE_EXCEPTIONS | |||
| #if __cpp_exceptions || __EXCEPTIONS || \ | |||
| (defined(_MSC_VER) && defined(_CPPUNWIND)) | |||
| #define MEGDNN_ENABLE_EXCEPTIONS 1 | |||
| #else | |||
| #define MEGDNN_ENABLE_EXCEPTIONS 0 | |||
| #endif | |||
| #if __cpp_exceptions || __EXCEPTIONS || (defined(_MSC_VER) && defined(_CPPUNWIND)) | |||
| #define MEGDNN_ENABLE_EXCEPTIONS 1 | |||
| #else | |||
| #define MEGDNN_ENABLE_EXCEPTIONS 0 | |||
| #endif | |||
| #endif | |||
| #ifndef MEGDNN_ENABLE_RTTI | |||
| #if __cpp_rtti || __GXX_RTTI || (defined(_MSC_VER) && defined(_CPPRTTI)) | |||
| #define MEGDNN_ENABLE_RTTI 1 | |||
| #else | |||
| #define MEGDNN_ENABLE_RTTI 0 | |||
| #endif | |||
| #if __cpp_rtti || __GXX_RTTI || (defined(_MSC_VER) && defined(_CPPRTTI)) | |||
| #define MEGDNN_ENABLE_RTTI 1 | |||
| #else | |||
| #define MEGDNN_ENABLE_RTTI 0 | |||
| #endif | |||
| #endif | |||
| #ifdef __CUDACC__ | |||
| #define MEGDNN_CC_CUDA 1 | |||
| #undef MEGDNN_CONSTEXPR | |||
| #define MEGDNN_CONSTEXPR const | |||
| #define MEGDNN_CC_CUDA 1 | |||
| #undef MEGDNN_CONSTEXPR | |||
| #define MEGDNN_CONSTEXPR const | |||
| #if defined(__CUDACC_VER_MAJOR__) | |||
| #if __CUDACC_VER_MAJOR__ >= 9 | |||
| #undef MEGDNN_STATIC_ASSERT | |||
| #define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg); | |||
| #undef MEGDNN_STATIC_ASSERT | |||
| #define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg); | |||
| #else | |||
| #undef MEGDNN_STATIC_ASSERT | |||
| #define MEGDNN_STATIC_ASSERT(cond, msg) | |||
| #undef MEGDNN_STATIC_ASSERT | |||
| #define MEGDNN_STATIC_ASSERT(cond, msg) | |||
| #endif | |||
| #endif | |||
| #define nullptr NULL | |||
| #undef MEGDNN_FINAL | |||
| #define MEGDNN_FINAL | |||
| #define nullptr NULL | |||
| #undef MEGDNN_FINAL | |||
| #define MEGDNN_FINAL | |||
| #elif defined(__HIPCC__) | |||
| #define MEGDNN_CC_CUDA 1 | |||
| #define MEGDNN_CC_CUDA 1 | |||
| #else | |||
| #define MEGDNN_CC_HOST 1 | |||
| #endif // __CUDACC__ | |||
| #define MEGDNN_CC_HOST 1 | |||
| #endif // __CUDACC__ | |||
| // MEGDNN_HOST and MEGDNN_DEVICE | |||
| #if MEGDNN_CC_CUDA | |||
| #define MEGDNN_HOST __host__ | |||
| #define MEGDNN_DEVICE __device__ | |||
| #define MEGDNN_HOST __host__ | |||
| #define MEGDNN_DEVICE __device__ | |||
| #else | |||
| #define MEGDNN_HOST | |||
| #define MEGDNN_DEVICE | |||
| #define MEGDNN_HOST | |||
| #define MEGDNN_DEVICE | |||
| #endif | |||
| #if MEGDNN_CC_CUDA | |||
| #define MEGDNN_FORCE_INLINE __forceinline__ | |||
| #define MEGDNN_FORCE_INLINE __forceinline__ | |||
| #else | |||
| #if __GNUC__ || __has_attribute(always_inline) | |||
| #define MEGDNN_FORCE_INLINE inline __attribute__((always_inline)) | |||
| #define MEGDNN_FORCE_INLINE inline __attribute__((always_inline)) | |||
| #else | |||
| #define MEGDNN_FORCE_INLINE inline | |||
| #define MEGDNN_FORCE_INLINE inline | |||
| #endif | |||
| #endif | |||
| #if defined(_MSC_VER) || defined(WIN32) | |||
| #define ATTR_ALIGNED(v) __declspec(align(v)) | |||
| #define ATTR_ALIGNED(v) __declspec(align(v)) | |||
| #else | |||
| #define ATTR_ALIGNED(v) __attribute__((aligned(v))) | |||
| #define ATTR_ALIGNED(v) __attribute__((aligned(v))) | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -16,10 +16,10 @@ | |||
| #include "megdnn/internal/defs.h" | |||
| #if MEGDNN_CC_HOST | |||
| #include <cstdarg> | |||
| #include <string> | |||
| #include <type_traits> | |||
| #include <vector> | |||
| #include <cstdarg> | |||
| #include "megdnn/thin/small_vector.h" | |||
| #endif // MEGDNN_CC_HOST | |||
| @@ -35,8 +35,7 @@ class ErrorHandler { | |||
| protected: | |||
| MEGDNN_NORETURN virtual void do_on_megdnn_error(const std::string& msg) = 0; | |||
| MEGDNN_NORETURN virtual void do_on_tensor_reshape_error( | |||
| const std::string& msg) { | |||
| MEGDNN_NORETURN virtual void do_on_tensor_reshape_error(const std::string& msg) { | |||
| on_megdnn_error(msg); | |||
| } | |||
| @@ -70,8 +69,9 @@ public: | |||
| #if MEGDNN_CC_HOST | |||
| enum class LogLevel { DEBUG, INFO, WARN, ERROR }; | |||
| typedef void (*LogHandler)(LogLevel level, const char* file, const char* func, | |||
| int line, const char* fmt, va_list ap); | |||
| typedef void (*LogHandler)( | |||
| LogLevel level, const char* file, const char* func, int line, const char* fmt, | |||
| va_list ap); | |||
| /*! | |||
| * \brief set the callback to receive all log messages | |||
| @@ -144,8 +144,7 @@ struct TensorLayout : public TensorShape { | |||
| ptrdiff_t low_elem, low_byte; | |||
| size_t high_elem, high_byte; | |||
| Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem, | |||
| size_t high_byte) | |||
| Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem, size_t high_byte) | |||
| : low_elem(low_elem), | |||
| low_byte(low_byte), | |||
| high_elem(high_elem), | |||
| @@ -235,11 +234,13 @@ struct TensorLayout : public TensorShape { | |||
| TensorLayout(const TensorShape& shape, DType dtype, Format format); | |||
| //! create a layout with user-specified shape and stride. | |||
| TensorLayout(const TensorShape& shape, const std::vector<ptrdiff_t>& stride, | |||
| DType dtype); | |||
| TensorLayout( | |||
| const TensorShape& shape, const std::vector<ptrdiff_t>& stride, | |||
| DType dtype); | |||
| TensorLayout(const TensorShape& shape, const std::vector<ptrdiff_t>& stride, | |||
| DType dtype, Format format); | |||
| TensorLayout( | |||
| const TensorShape& shape, const std::vector<ptrdiff_t>& stride, DType dtype, | |||
| Format format); | |||
| /* =================== inplace modifiers =================== */ | |||
| @@ -310,8 +311,7 @@ struct TensorLayout : public TensorShape { | |||
| * | |||
| * \throw TensorReshapeError if no stride exists for target shape. | |||
| */ | |||
| TensorLayout reshape(const TensorShape& shape) const | |||
| MEGDNN_WARN_UNUSED_RESULT; | |||
| TensorLayout reshape(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT; | |||
| /*! | |||
| * \brief try to reshape to another view; return whether these two shapes | |||
| @@ -319,15 +319,14 @@ struct TensorLayout : public TensorShape { | |||
| * \return true iff there exists target stride so this layout can be | |||
| * converted to target shape and the elements can match. | |||
| */ | |||
| bool try_reshape(TensorLayout& output, | |||
| const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT; | |||
| bool try_reshape(TensorLayout& output, const TensorShape& shape) const | |||
| MEGDNN_WARN_UNUSED_RESULT; | |||
| /*! | |||
| * \brief Broadcast on dims with shape == 1 to match target *shape*. | |||
| * \throw TensorReshapeError if the broadcast could not be satisfied | |||
| */ | |||
| TensorLayout broadcast(const TensorShape& shape) const | |||
| MEGDNN_WARN_UNUSED_RESULT; | |||
| TensorLayout broadcast(const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT; | |||
| /*! | |||
| * \brief Collapse consecutive axes with contiguous layout together | |||
| @@ -441,8 +440,7 @@ struct Workspace { | |||
| Workspace() : raw_ptr(NULL), size(0) {} | |||
| Workspace(dt_byte* raw_ptr_, size_t size_) | |||
| : raw_ptr(raw_ptr_), size(size_) {} | |||
| Workspace(dt_byte* raw_ptr_, size_t size_) : raw_ptr(raw_ptr_), size(size_) {} | |||
| template <typename T> | |||
| T* ptr(size_t offset_in_bytes = 0) const { | |||
| @@ -467,9 +465,8 @@ public: | |||
| * \param shape requested output shape | |||
| * \param user_data extra user data passed in DynOutMallocPolicyCall | |||
| */ | |||
| virtual TensorND alloc_output(size_t id, DType dtype, | |||
| const TensorShape& shape, | |||
| void* user_data) = 0; | |||
| virtual TensorND alloc_output( | |||
| size_t id, DType dtype, const TensorShape& shape, void* user_data) = 0; | |||
| /*! | |||
| * \brief allocate workspace memory | |||
| @@ -508,19 +505,15 @@ struct DynOutMallocPolicyCall { | |||
| */ | |||
| template <typename T = void, typename elem = T> | |||
| T* alloc_workspace(size_t nr_elem) { | |||
| using real_elem = | |||
| typename std::conditional<std::is_same<elem, void>::value, | |||
| uint8_t, elem>::type; | |||
| return static_cast<T*>(policy->alloc_workspace( | |||
| nr_elem * sizeof(real_elem), user_data)); | |||
| using real_elem = typename std::conditional< | |||
| std::is_same<elem, void>::value, uint8_t, elem>::type; | |||
| return static_cast<T*>( | |||
| policy->alloc_workspace(nr_elem * sizeof(real_elem), user_data)); | |||
| } | |||
| void free_workspace(void* ptr) { | |||
| return policy->free_workspace(ptr, user_data); | |||
| } | |||
| void free_workspace(void* ptr) { return policy->free_workspace(ptr, user_data); } | |||
| }; | |||
| template <typename T> | |||
| class EnumClassBit { | |||
| std::underlying_type_t<T> m_val; | |||
| @@ -528,8 +521,7 @@ class EnumClassBit { | |||
| constexpr EnumClassBit(std::underlying_type_t<T> v) : m_val(v) {} | |||
| public: | |||
| constexpr EnumClassBit(T v) | |||
| : m_val(static_cast<std::underlying_type_t<T>>(v)) {} | |||
| constexpr EnumClassBit(T v) : m_val(static_cast<std::underlying_type_t<T>>(v)) {} | |||
| constexpr operator T() const { return static_cast<T>(m_val); } | |||
| @@ -542,7 +534,7 @@ public: | |||
| DEF_OPR(&) | |||
| DEF_OPR(|) | |||
| DEF_OPR (^) | |||
| DEF_OPR(^) | |||
| constexpr EnumClassBit operator~() const { return ~m_val; } | |||
| @@ -553,14 +545,13 @@ public: | |||
| } // namespace megdnn | |||
| #define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \ | |||
| inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \ | |||
| return ::megdnn::EnumClassBit<cls>(x) \ | |||
| op ::megdnn::EnumClassBit<cls>(y); \ | |||
| } \ | |||
| inline constexpr ::megdnn::EnumClassBit<cls> operator op( \ | |||
| ::megdnn::EnumClassBit<cls> x, cls y) { \ | |||
| return x op ::megdnn::EnumClassBit<cls>(y); \ | |||
| #define _MEGDNN_DECBO_SINGLE_OPR(cls, op) \ | |||
| inline constexpr ::megdnn::EnumClassBit<cls> operator op(cls x, cls y) { \ | |||
| return ::megdnn::EnumClassBit<cls>(x) op ::megdnn::EnumClassBit<cls>(y); \ | |||
| } \ | |||
| inline constexpr ::megdnn::EnumClassBit<cls> operator op( \ | |||
| ::megdnn::EnumClassBit<cls> x, cls y) { \ | |||
| return x op ::megdnn::EnumClassBit<cls>(y); \ | |||
| } | |||
| #define _MEGDNN_DECBO_SINGLE_OPR_ASSIGN(cls, op) \ | |||
| @@ -14,14 +14,14 @@ | |||
| #include "megbrain_build_config.h" | |||
| #if MGB_ENABLE_GETENV | |||
| #define MGB_GETENV ::std::getenv | |||
| #define MGB_GETENV ::std::getenv | |||
| #else | |||
| #define MGB_GETENV(_name) static_cast<char*>(nullptr) | |||
| #define MGB_GETENV(_name) static_cast<char*>(nullptr) | |||
| #endif | |||
| #ifdef WIN32 | |||
| #define unsetenv(_name) _putenv_s(_name, ""); | |||
| #define setenv(name,value,overwrite) _putenv_s(name,value) | |||
| #define unsetenv(_name) _putenv_s(_name, ""); | |||
| #define setenv(name, value, overwrite) _putenv_s(name, value) | |||
| #endif | |||
| namespace megdnn { | |||
| @@ -32,8 +32,7 @@ namespace megdnn { | |||
| */ | |||
| template <class Opr, typename... Args> | |||
| bool has_available_algo(Opr* opr, Args&&... args) { | |||
| const typename Opr::AlgoBase::SizeArgs size_args( | |||
| opr, std::forward<Args>(args)...); | |||
| const typename Opr::AlgoBase::SizeArgs size_args(opr, std::forward<Args>(args)...); | |||
| for (auto i : Opr::algo_pack().all_algos) { | |||
| if (i->is_available(size_args)) { | |||
| return true; | |||
| @@ -42,6 +41,6 @@ bool has_available_algo(Opr* opr, Args&&... args) { | |||
| return false; | |||
| } | |||
| } | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -17,11 +17,11 @@ | |||
| #include "megdnn/internal/visibility_prologue.h" | |||
| namespace megdnn { | |||
| std::unique_ptr<Handle> make_cuda_handle_with_stream(cudaStream_t stream, | |||
| int device_id = -1); | |||
| cudaStream_t get_cuda_stream(Handle *handle); | |||
| std::unique_ptr<Handle> make_cuda_handle_with_stream( | |||
| cudaStream_t stream, int device_id = -1); | |||
| cudaStream_t get_cuda_stream(Handle* handle); | |||
| } // namespace megdnn | |||
| } // namespace megdnn | |||
| #include "megdnn/internal/visibility_epilogue.h" | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -3,17 +3,22 @@ | |||
| * | |||
| * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net> | |||
| * | |||
| * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation | |||
| * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, | |||
| * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the | |||
| * Software is furnished to do so, subject to the following conditions: | |||
| * | |||
| * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. | |||
| * | |||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE | |||
| * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR | |||
| * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, | |||
| * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |||
| * Permission is hereby granted, free of charge, to any person obtaining a copy of this | |||
| * software and associated documentation files (the "Software"), to deal in the Software | |||
| * without restriction, including without limitation the rights to use, copy, modify, | |||
| * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to | |||
| * permit persons to whom the Software is furnished to do so, subject to the following | |||
| * conditions: | |||
| * | |||
| * The above copyright notice and this permission notice shall be included in all copies | |||
| * or substantial portions of the Software. | |||
| * | |||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, | |||
| * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A | |||
| * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT | |||
| * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF | |||
| * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE | |||
| * OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |||
| * | |||
| * Version 1.11.0 | |||
| * \file | |||
| @@ -41,8 +46,8 @@ | |||
| #undef HALF_NOEXCEPT | |||
| #undef HALF_NOTHROW | |||
| #ifdef HALF_POP_WARNINGS | |||
| #pragma warning(pop) | |||
| #undef HALF_POP_WARNINGS | |||
| #pragma warning(pop) | |||
| #undef HALF_POP_WARNINGS | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -3,17 +3,22 @@ | |||
| * | |||
| * Copyright (c) 2012-2013 Christian Rau <rauy@users.sourceforge.net> | |||
| * | |||
| * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation | |||
| * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, | |||
| * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the | |||
| * Software is furnished to do so, subject to the following conditions: | |||
| * Permission is hereby granted, free of charge, to any person obtaining a copy of this | |||
| * software and associated documentation files (the "Software"), to deal in the Software | |||
| * without restriction, including without limitation the rights to use, copy, modify, | |||
| * merge, publish, distribute, sublicense, and/or sell copies of the Software, and to | |||
| * permit persons to whom the Software is furnished to do so, subject to the following | |||
| * conditions: | |||
| * | |||
| * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. | |||
| * The above copyright notice and this permission notice shall be included in all copies | |||
| * or substantial portions of the Software. | |||
| * | |||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE | |||
| * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR | |||
| * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, | |||
| * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, | |||
| * INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A | |||
| * PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT | |||
| * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF | |||
| * CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE | |||
| * OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. | |||
| * | |||
| * Version 1.11.0 | |||
| * \file | |||
| @@ -39,166 +44,164 @@ | |||
| #include "megdnn/arch.h" | |||
| /// Combined gcc version number. | |||
| #define HALF_GNUC_VERSION (__GNUC__*100+__GNUC_MINOR__) | |||
| #define HALF_GNUC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) | |||
| //check C++11 language features | |||
| #if defined(__clang__) //clang | |||
| #if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||
| #endif | |||
| #if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) | |||
| #define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||
| #endif | |||
| #if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) | |||
| #define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||
| #endif | |||
| #if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) | |||
| #define HALF_ENABLE_CPP11_USER_LITERALS 1 | |||
| #endif | |||
| #if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||
| #endif | |||
| /*#elif defined(__INTEL_COMPILER) //Intel C++ | |||
| #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) ???????? | |||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||
| #endif | |||
| #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? | |||
| #define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||
| #endif | |||
| #if __INTEL_COMPILER >= 1300 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? | |||
| #define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||
| #endif | |||
| #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? | |||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||
| #endif*/ | |||
| #elif defined(__GNUC__) //gcc | |||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L | |||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) | |||
| #define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) | |||
| #define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) | |||
| #define HALF_ENABLE_CPP11_USER_LITERALS 1 | |||
| #endif | |||
| #if !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||
| #endif | |||
| #endif | |||
| #elif defined(_MSC_VER) //Visual C++ | |||
| #if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||
| #endif | |||
| #if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||
| #endif | |||
| #define HALF_POP_WARNINGS 1 | |||
| #pragma warning(push) | |||
| //! 4521 and 4522 is multiple copy/assigment operator specified | |||
| #pragma warning(disable : 4099 4127 4146 4521 4522) //struct vs class, constant in if, negative unsigned | |||
| // check C++11 language features | |||
| #if defined(__clang__) // clang | |||
| #if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||
| #endif | |||
| #if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR) | |||
| #define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||
| #endif | |||
| #if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT) | |||
| #define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||
| #endif | |||
| #if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS) | |||
| #define HALF_ENABLE_CPP11_USER_LITERALS 1 | |||
| #endif | |||
| #if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && \ | |||
| !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||
| #endif | |||
| /*#elif defined(__INTEL_COMPILER) | |||
| //Intel C++ #if __INTEL_COMPILER >= 1100 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||
| ???????? #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 #endif #if __INTEL_COMPILER >= | |||
| 1300 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) ???????? #define | |||
| HALF_ENABLE_CPP11_CONSTEXPR 1 #endif #if __INTEL_COMPILER >= 1300 && | |||
| !defined(HALF_ENABLE_CPP11_NOEXCEPT) ???????? #define | |||
| HALF_ENABLE_CPP11_NOEXCEPT 1 #endif #if __INTEL_COMPILER >= 1100 && | |||
| !defined(HALF_ENABLE_CPP11_LONG_LONG) ???????? #define | |||
| HALF_ENABLE_CPP11_LONG_LONG 1 #endif*/ | |||
| #elif defined(__GNUC__) // gcc | |||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L | |||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR) | |||
| #define HALF_ENABLE_CPP11_CONSTEXPR 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT) | |||
| #define HALF_ENABLE_CPP11_NOEXCEPT 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS) | |||
| #define HALF_ENABLE_CPP11_USER_LITERALS 1 | |||
| #endif | |||
| #if !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||
| #endif | |||
| #endif | |||
| #elif defined(_MSC_VER) // Visual C++ | |||
| #if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT) | |||
| #define HALF_ENABLE_CPP11_STATIC_ASSERT 1 | |||
| #endif | |||
| #if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG) | |||
| #define HALF_ENABLE_CPP11_LONG_LONG 1 | |||
| #endif | |||
| #define HALF_POP_WARNINGS 1 | |||
| #pragma warning(push) | |||
| //! 4521 and 4522 is multiple copy/assigment operator specified | |||
| #pragma warning(disable : 4099 4127 4146 4521 4522) // struct vs class, constant in if, | |||
| // negative unsigned | |||
| #endif | |||
| //check C++11 library features | |||
| // check C++11 library features | |||
| #include <utility> | |||
| #if defined(_LIBCPP_VERSION) //libc++ | |||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 | |||
| #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS | |||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_CSTDINT | |||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_CMATH | |||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_HASH | |||
| #define HALF_ENABLE_CPP11_HASH 1 | |||
| #endif | |||
| #endif | |||
| #elif defined(__GLIBCXX__) //libstdc++ | |||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 | |||
| #ifdef __clang__ | |||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) | |||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||
| #endif | |||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) | |||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||
| #endif | |||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) | |||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||
| #endif | |||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) | |||
| #define HALF_ENABLE_CPP11_HASH 1 | |||
| #endif | |||
| #else | |||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) | |||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) | |||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) | |||
| #define HALF_ENABLE_CPP11_HASH 1 | |||
| #endif | |||
| #endif | |||
| #endif | |||
| #elif defined(_CPPLIB_VER) //Dinkumware/Visual C++ | |||
| #if _CPPLIB_VER >= 520 | |||
| #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS | |||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_CSTDINT | |||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_HASH | |||
| #define HALF_ENABLE_CPP11_HASH 1 | |||
| #endif | |||
| #endif | |||
| #if _CPPLIB_VER >= 610 | |||
| #ifndef HALF_ENABLE_CPP11_CMATH | |||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||
| #endif | |||
| #endif | |||
| #if defined(_LIBCPP_VERSION) // libc++ | |||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 | |||
| #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS | |||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_CSTDINT | |||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_CMATH | |||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_HASH | |||
| #define HALF_ENABLE_CPP11_HASH 1 | |||
| #endif | |||
| #endif | |||
| #elif defined(__GLIBCXX__) // libstdc++ | |||
| #if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103 | |||
| #ifdef __clang__ | |||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS) | |||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||
| #endif | |||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT) | |||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||
| #endif | |||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH) | |||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||
| #endif | |||
| #if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH) | |||
| #define HALF_ENABLE_CPP11_HASH 1 | |||
| #endif | |||
| #else | |||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT) | |||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH) | |||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||
| #endif | |||
| #if HALF_GNUC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH) | |||
| #define HALF_ENABLE_CPP11_HASH 1 | |||
| #endif | |||
| #endif | |||
| #endif | |||
| #elif defined(_CPPLIB_VER) // Dinkumware/Visual C++ | |||
| #if _CPPLIB_VER >= 520 | |||
| #ifndef HALF_ENABLE_CPP11_TYPE_TRAITS | |||
| #define HALF_ENABLE_CPP11_TYPE_TRAITS 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_CSTDINT | |||
| #define HALF_ENABLE_CPP11_CSTDINT 1 | |||
| #endif | |||
| #ifndef HALF_ENABLE_CPP11_HASH | |||
| #define HALF_ENABLE_CPP11_HASH 1 | |||
| #endif | |||
| #endif | |||
| #if _CPPLIB_VER >= 610 | |||
| #ifndef HALF_ENABLE_CPP11_CMATH | |||
| #define HALF_ENABLE_CPP11_CMATH 1 | |||
| #endif | |||
| #endif | |||
| #endif | |||
| #undef HALF_GNUC_VERSION | |||
| //support constexpr | |||
| // support constexpr | |||
| #if HALF_ENABLE_CPP11_CONSTEXPR | |||
| #define HALF_CONSTEXPR constexpr | |||
| #define HALF_CONSTEXPR_CONST constexpr | |||
| #define HALF_CONSTEXPR constexpr | |||
| #define HALF_CONSTEXPR_CONST constexpr | |||
| #else | |||
| #define HALF_CONSTEXPR | |||
| #define HALF_CONSTEXPR_CONST const | |||
| #define HALF_CONSTEXPR | |||
| #define HALF_CONSTEXPR_CONST const | |||
| #endif | |||
| //support noexcept | |||
| // support noexcept | |||
| #if HALF_ENABLE_CPP11_NOEXCEPT | |||
| #define HALF_NOEXCEPT noexcept | |||
| #define HALF_NOTHROW noexcept | |||
| #define HALF_NOEXCEPT noexcept | |||
| #define HALF_NOTHROW noexcept | |||
| #else | |||
| #define HALF_NOEXCEPT | |||
| #define HALF_NOTHROW throw() | |||
| #define HALF_NOEXCEPT | |||
| #define HALF_NOTHROW throw() | |||
| #endif | |||
| #include <algorithm> | |||
| #include <limits> | |||
| #include <climits> | |||
| #include <cmath> | |||
| #include <cstring> | |||
| #include <ostream> | |||
| #include <istream> | |||
| #include <limits> | |||
| #include <ostream> | |||
| #if HALF_ENABLE_CPP11_TYPE_TRAITS | |||
| #include <type_traits> | |||
| #include <type_traits> | |||
| #endif | |||
| #if HALF_ENABLE_CPP11_CSTDINT | |||
| #include <cstdint> | |||
| #include <cstdint> | |||
| #endif | |||
| #if HALF_ENABLE_CPP11_HASH | |||
| #include <functional> | |||
| #include <functional> | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -12,8 +12,8 @@ | |||
| #pragma once | |||
| #include "megcore.h" | |||
| #include "megdnn/config/config.h" | |||
| #include "megdnn/basic_types.h" | |||
| #include "megdnn/config/config.h" | |||
| #include <functional> | |||
| #include <memory> | |||
| @@ -24,150 +24,147 @@ namespace megdnn { | |||
| class OperatorBase; | |||
| class Handle { | |||
| public: | |||
| enum class HandleType { | |||
| NAIVE = 0, | |||
| FALLBACK = 1, | |||
| X86 = 2, | |||
| ARM_COMMON = 3, | |||
| ARMV7 = 4, | |||
| AARCH64 = 5, | |||
| CUDA = 6, | |||
| ROCM = 11, | |||
| ATLAS = 13, | |||
| CAMBRICON = 12, | |||
| }; | |||
| //! Device vendor | |||
| enum class HandleVendorType : uint32_t { | |||
| NOT_SPEC = 0, | |||
| MALI = 1, | |||
| ADRENO = 2, | |||
| CUDA = 3, | |||
| INTEL = 4, | |||
| POWERVR = 5, | |||
| AMD = 6, | |||
| }; | |||
| protected: | |||
| Handle(megcoreComputingHandle_t computing_handle, HandleType type); | |||
| public: | |||
| /** | |||
| * \brief Create a MegDNN handle from a MegCore Computing handle. | |||
| * | |||
| * \param[in] computing_handle MegCore computing handle. Please note | |||
| * that computing_handle would not be released when this Handle is | |||
| * destructed | |||
| * \param[in] debug_level | |||
| * Applicable for CPU computing handle. | |||
| * 0 means taking the fastest possible code path; it may contains | |||
| * platform-specific instructions such as SSE for x86_64 or NEON for | |||
| * armv7v7. | |||
| * 1 means taking the fastest possible code path without | |||
| * platform-specific instructions in C++ code. Note that the compiled | |||
| * binary file still contains platform-specific codes. | |||
| * 2 means taking the naive code path. Performance is severely | |||
| * hampered, but it is less error-prone since the internal | |||
| * implementation is rather straightforward. | |||
| * | |||
| * **Debug level 1 and 2 should not be used in productions.** | |||
| */ | |||
| static std::unique_ptr<Handle> make( | |||
| megcoreComputingHandle_t computing_handle, | |||
| int debug_level = 0); | |||
| public: | |||
| enum class HandleType { | |||
| NAIVE = 0, | |||
| FALLBACK = 1, | |||
| X86 = 2, | |||
| ARM_COMMON = 3, | |||
| ARMV7 = 4, | |||
| AARCH64 = 5, | |||
| CUDA = 6, | |||
| ROCM = 11, | |||
| ATLAS = 13, | |||
| CAMBRICON = 12, | |||
| }; | |||
| //! Device vendor | |||
| enum class HandleVendorType : uint32_t { | |||
| NOT_SPEC = 0, | |||
| MALI = 1, | |||
| ADRENO = 2, | |||
| CUDA = 3, | |||
| INTEL = 4, | |||
| POWERVR = 5, | |||
| AMD = 6, | |||
| }; | |||
| protected: | |||
| Handle(megcoreComputingHandle_t computing_handle, HandleType type); | |||
| public: | |||
| /** | |||
| * \brief Create a MegDNN handle from a MegCore Computing handle. | |||
| * | |||
| * \param[in] computing_handle MegCore computing handle. Please note | |||
| * that computing_handle would not be released when this Handle is | |||
| * destructed | |||
| * \param[in] debug_level | |||
| * Applicable for CPU computing handle. | |||
| * 0 means taking the fastest possible code path; it may contains | |||
| * platform-specific instructions such as SSE for x86_64 or NEON for | |||
| * armv7v7. | |||
| * 1 means taking the fastest possible code path without | |||
| * platform-specific instructions in C++ code. Note that the compiled | |||
| * binary file still contains platform-specific codes. | |||
| * 2 means taking the naive code path. Performance is severely | |||
| * hampered, but it is less error-prone since the internal | |||
| * implementation is rather straightforward. | |||
| * | |||
| * **Debug level 1 and 2 should not be used in productions.** | |||
| */ | |||
| static std::unique_ptr<Handle> make( | |||
| megcoreComputingHandle_t computing_handle, int debug_level = 0); | |||
| #if MEGDNN_WITH_CUDA | |||
| static std::unique_ptr<Handle> make_cuda_handle( | |||
| megcoreComputingHandle_t computing_handle); | |||
| template <typename opr> | |||
| std::unique_ptr<opr> create_cuda_operator(); | |||
| static std::unique_ptr<Handle> make_cuda_handle( | |||
| megcoreComputingHandle_t computing_handle); | |||
| template <typename opr> | |||
| std::unique_ptr<opr> create_cuda_operator(); | |||
| #endif | |||
| #if MEGDNN_WITH_ROCM | |||
| static std::unique_ptr<Handle> make_rocm_handle( | |||
| megcoreComputingHandle_t computing_handle); | |||
| template <typename opr> | |||
| std::unique_ptr<opr> create_rocm_operator(); | |||
| static std::unique_ptr<Handle> make_rocm_handle( | |||
| megcoreComputingHandle_t computing_handle); | |||
| template <typename opr> | |||
| std::unique_ptr<opr> create_rocm_operator(); | |||
| #endif | |||
| virtual ~Handle(); | |||
| /*! | |||
| * \brief Get the underlying megcore computing handle. | |||
| */ | |||
| megcoreComputingHandle_t megcore_computing_handle() const { | |||
| return m_computing_handle; | |||
| } | |||
| /*! | |||
| * \brief set a callback function to be invoked when this handle is | |||
| * destructed, so associated resources can be released (e.g. | |||
| * computing handle) | |||
| * | |||
| * This function can be called at most once. | |||
| */ | |||
| void set_destructor(const thin_function<void()> &d); | |||
| /*! | |||
| * \brief set a callback to be invoked when an operator is destructed | |||
| * \param[in,out] cb the callback function; it would be set to the | |||
| * previous callback function | |||
| */ | |||
| void set_opr_destruct_callback(thin_function<void(OperatorBase*)> &cb) { | |||
| cb.swap(m_on_opr_destructed); | |||
| } | |||
| void on_opr_destructed(OperatorBase* opr); | |||
| /** | |||
| * \brief Create operator of Opr type. | |||
| */ | |||
| template <typename Opr> | |||
| std::unique_ptr<Opr> create_operator(); | |||
| /* | |||
| * ============================================================= | |||
| * Users should call functions below to query memory requirement. | |||
| * ============================================================= | |||
| */ | |||
| /** | |||
| * \brief The internal data pointer of TensorND should be aligned to | |||
| * alignment_requirement() in bytes. | |||
| */ | |||
| virtual size_t alignment_requirement() const; | |||
| //! get alignment in bytes for rows of image 2D tensor format | |||
| virtual size_t image2d_pitch_alignment() const; | |||
| //! get vendor type | |||
| virtual HandleVendorType vendor_type() const; | |||
| HandleType type() const { | |||
| return m_handle_type; | |||
| } | |||
| /** | |||
| * \brief Check is the layout satisfy cross device copy constraint. | |||
| * 1. The handle of the src and the dst is the same kind | |||
| * 2. The dst is continguous. | |||
| */ | |||
| virtual bool check_cross_dev_copy_constraint(const TensorLayout &src); | |||
| private: | |||
| static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u; | |||
| volatile uint32_t m_alive_magic = ALIVE_MAGIC; | |||
| megcoreComputingHandle_t m_computing_handle; | |||
| const HandleType m_handle_type; | |||
| thin_function<void()> m_destructor; | |||
| thin_function<void(OperatorBase*)> m_on_opr_destructed; | |||
| Handle() = delete; | |||
| Handle(const Handle &rhs) = delete; | |||
| Handle &operator=(const Handle &rhs) = delete; | |||
| virtual ~Handle(); | |||
| /*! | |||
| * \brief Get the underlying megcore computing handle. | |||
| */ | |||
| megcoreComputingHandle_t megcore_computing_handle() const { | |||
| return m_computing_handle; | |||
| } | |||
| /*! | |||
| * \brief set a callback function to be invoked when this handle is | |||
| * destructed, so associated resources can be released (e.g. | |||
| * computing handle) | |||
| * | |||
| * This function can be called at most once. | |||
| */ | |||
| void set_destructor(const thin_function<void()>& d); | |||
| /*! | |||
| * \brief set a callback to be invoked when an operator is destructed | |||
| * \param[in,out] cb the callback function; it would be set to the | |||
| * previous callback function | |||
| */ | |||
| void set_opr_destruct_callback(thin_function<void(OperatorBase*)>& cb) { | |||
| cb.swap(m_on_opr_destructed); | |||
| } | |||
| void on_opr_destructed(OperatorBase* opr); | |||
| /** | |||
| * \brief Create operator of Opr type. | |||
| */ | |||
| template <typename Opr> | |||
| std::unique_ptr<Opr> create_operator(); | |||
| /* | |||
| * ============================================================= | |||
| * Users should call functions below to query memory requirement. | |||
| * ============================================================= | |||
| */ | |||
| /** | |||
| * \brief The internal data pointer of TensorND should be aligned to | |||
| * alignment_requirement() in bytes. | |||
| */ | |||
| virtual size_t alignment_requirement() const; | |||
| //! get alignment in bytes for rows of image 2D tensor format | |||
| virtual size_t image2d_pitch_alignment() const; | |||
| //! get vendor type | |||
| virtual HandleVendorType vendor_type() const; | |||
| HandleType type() const { return m_handle_type; } | |||
| /** | |||
| * \brief Check is the layout satisfy cross device copy constraint. | |||
| * 1. The handle of the src and the dst is the same kind | |||
| * 2. The dst is continguous. | |||
| */ | |||
| virtual bool check_cross_dev_copy_constraint(const TensorLayout& src); | |||
| private: | |||
| static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u; | |||
| volatile uint32_t m_alive_magic = ALIVE_MAGIC; | |||
| megcoreComputingHandle_t m_computing_handle; | |||
| const HandleType m_handle_type; | |||
| thin_function<void()> m_destructor; | |||
| thin_function<void(OperatorBase*)> m_on_opr_destructed; | |||
| Handle() = delete; | |||
| Handle(const Handle& rhs) = delete; | |||
| Handle& operator=(const Handle& rhs) = delete; | |||
| }; | |||
| } // namespace megdnn | |||
| } // namespace megdnn | |||
| #include "megdnn/internal/visibility_epilogue.h" | |||
| @@ -49,8 +49,9 @@ public: | |||
| mutable std::string m_input; | |||
| public: | |||
| Key(Handle* opr_handle, Algorithm::OprType opr_type, const TensorLayout* inp_layouts_ptr, | |||
| size_t inp_layouts_size, const void* param_ptr = nullptr, size_t param_size = 0) | |||
| Key(Handle* opr_handle, Algorithm::OprType opr_type, | |||
| const TensorLayout* inp_layouts_ptr, size_t inp_layouts_size, | |||
| const void* param_ptr = nullptr, size_t param_size = 0) | |||
| : m_handle{opr_handle}, | |||
| m_opr_type{static_cast<uint32_t>(opr_type)}, | |||
| m_inp_layouts_ptr{inp_layouts_ptr}, | |||
| @@ -16,20 +16,19 @@ | |||
| * \brief iterate through small (usually used) ndim values | |||
| */ | |||
| #define MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ...) \ | |||
| cb(1 ,##__VA_ARGS__) cb(2 ,##__VA_ARGS__) cb(3 ,##__VA_ARGS__) | |||
| cb(1, ##__VA_ARGS__) cb(2, ##__VA_ARGS__) cb(3, ##__VA_ARGS__) | |||
| /*! | |||
| * \brief iterate through large (rarely used) ndim values | |||
| */ | |||
| #define MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ...) \ | |||
| cb(4 ,##__VA_ARGS__) cb(5 ,##__VA_ARGS__) cb(6 ,##__VA_ARGS__) \ | |||
| cb(7, ##__VA_ARGS__) | |||
| cb(4, ##__VA_ARGS__) cb(5, ##__VA_ARGS__) cb(6, ##__VA_ARGS__) cb(7, ##__VA_ARGS__) | |||
| /*! | |||
| * \brief iterate through all ndim values | |||
| */ | |||
| #define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) \ | |||
| MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb ,##__VA_ARGS__) \ | |||
| MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb ,##__VA_ARGS__) | |||
| #define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) \ | |||
| MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ##__VA_ARGS__) \ | |||
| MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ##__VA_ARGS__) | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -11,14 +11,14 @@ | |||
| // intentional no header guard here | |||
| #include "megdnn/handle.h" | |||
| #include "megdnn/oprs/base.h" | |||
| #include "megdnn/opr_param_defs.h" | |||
| #include "megdnn/opr_result_defs.h" | |||
| #include "megdnn/oprs/base.h" | |||
| #include "./visibility_prologue.h" | |||
| #include <limits> | |||
| #include <array> | |||
| #include <limits> | |||
| #ifndef _megdnn_in | |||
| #define _megdnn_in | |||
| @@ -29,36 +29,37 @@ | |||
| #endif | |||
| #ifndef _megdnn_tensor_in | |||
| #define _megdnn_tensor_in const TensorND & | |||
| #define _megdnn_tensor_in const TensorND& | |||
| #endif | |||
| #ifndef _megdnn_tensor_out | |||
| #define _megdnn_tensor_out const TensorND & | |||
| #define _megdnn_tensor_out const TensorND& | |||
| #endif | |||
| #ifndef _megdnn_tensor_inout | |||
| #define _megdnn_tensor_inout const TensorND & | |||
| #define _megdnn_tensor_inout const TensorND& | |||
| #endif | |||
| #ifndef _megdnn_workspace | |||
| #define _megdnn_workspace const Workspace & | |||
| #define _megdnn_workspace const Workspace& | |||
| #endif | |||
| #define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||
| public: \ | |||
| _opr_name(Handle *handle): _base_name(handle) {} \ | |||
| #define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||
| public: \ | |||
| _opr_name(Handle* handle) : _base_name(handle) {} | |||
| #define DEF_OPR_IMPL(_opr_name, _base_name, _nr_inputs, _nr_outputs) \ | |||
| DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||
| static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \ | |||
| static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs; \ | |||
| DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||
| static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \ | |||
| static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs; | |||
| #define DEF_OPR_PARAM(_pname) \ | |||
| public: \ | |||
| using Param = param::_pname; \ | |||
| Param& param() { return m_param; } \ | |||
| const Param& param() const { return m_param; } \ | |||
| protected: \ | |||
| Param m_param | |||
| #define DEF_OPR_PARAM(_pname) \ | |||
| public: \ | |||
| using Param = param::_pname; \ | |||
| Param& param() { return m_param; } \ | |||
| const Param& param() const { return m_param; } \ | |||
| \ | |||
| protected: \ | |||
| Param m_param | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -20,4 +20,3 @@ | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -16,25 +16,21 @@ | |||
| namespace megdnn { | |||
| namespace opr_result { | |||
| struct Checksum { | |||
| uint32_t checksum; | |||
| union { | |||
| int32_t iv; | |||
| float fv; | |||
| } last_val; | |||
| bool operator == (const Checksum &rhs) const { | |||
| return checksum == rhs.checksum && | |||
| last_val.iv == rhs.last_val.iv; | |||
| } | |||
| bool operator != (const Checksum &rhs) const { | |||
| return !operator==(rhs); | |||
| } | |||
| }; | |||
| } // namespace opr_result | |||
| } // namespace megdnn | |||
| struct Checksum { | |||
| uint32_t checksum; | |||
| union { | |||
| int32_t iv; | |||
| float fv; | |||
| } last_val; | |||
| bool operator==(const Checksum& rhs) const { | |||
| return checksum == rhs.checksum && last_val.iv == rhs.last_val.iv; | |||
| } | |||
| bool operator!=(const Checksum& rhs) const { return !operator==(rhs); } | |||
| }; | |||
| } // namespace opr_result | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -12,11 +12,11 @@ | |||
| #include "megdnn/oprs/cv.h" | |||
| #include "megdnn/oprs/general.h" | |||
| #include "megdnn/oprs/imgproc.h" | |||
| #include "megdnn/oprs/linalg.h" | |||
| #include "megdnn/oprs/nn.h" | |||
| #include "megdnn/oprs/nn_int.h" | |||
| #include "megdnn/oprs/imgproc.h" | |||
| #include "megdnn/oprs/utils.h" | |||
| #include "megdnn/oprs/linalg.h" | |||
| template <typename Opr> | |||
| struct OprArityTrait; | |||
| @@ -53,6 +53,4 @@ INST_ARITY(megdnn::PoolingBackward, 3, 1); | |||
| #undef INST_ARITY | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -90,7 +90,7 @@ enum class AlgoDataType : uint32_t { | |||
| INT8X8X16 = 1 << 4, | |||
| INT16X16X32 = 1 << 5, | |||
| INT4X4X16 = 1 << 6, | |||
| QINT4x4x32 = 1 << 7, | |||
| QINT4x4x32 = 1 << 7, | |||
| }; | |||
| /*! | |||
| @@ -195,16 +195,16 @@ public: | |||
| Handle::HandleType handle_type() const { return m_handle_type; } | |||
| Info::Desc desc() const { return {handle_type(), type(), param(), name()}; } | |||
| Info info() const { | |||
| return {desc(), attribute()}; | |||
| } | |||
| Info info() const { return {desc(), attribute()}; } | |||
| template <typename T> | |||
| static void serialize_write_pod(const T& val, std::string& result) { | |||
| static_assert(std::is_trivially_copyable<T>::value, | |||
| "type should be trivially copyable"); | |||
| static_assert(!std::is_pointer<T>::value, | |||
| "serialize pointer is unsafe in eager execution mode"); | |||
| static_assert( | |||
| std::is_trivially_copyable<T>::value, | |||
| "type should be trivially copyable"); | |||
| static_assert( | |||
| !std::is_pointer<T>::value, | |||
| "serialize pointer is unsafe in eager execution mode"); | |||
| result.append(reinterpret_cast<const char*>(&val), sizeof(T)); | |||
| } | |||
| @@ -231,9 +231,8 @@ public: | |||
| return ret; | |||
| } | |||
| static std::string deserialize_read_pod(const std::string& data, | |||
| size_t offset = 0, | |||
| size_t size = 0) { | |||
| static std::string deserialize_read_pod( | |||
| const std::string& data, size_t offset = 0, size_t size = 0) { | |||
| return std::string(data.data() + offset, size); | |||
| } | |||
| @@ -286,8 +285,8 @@ public: | |||
| * \param layouts origin layouts of the parent opr | |||
| * \param opr parent opr | |||
| */ | |||
| virtual std::vector<SearchItem> get_subopr_list(const TensorLayoutArray&, | |||
| const OperatorBase*) const { | |||
| virtual std::vector<SearchItem> get_subopr_list( | |||
| const TensorLayoutArray&, const OperatorBase*) const { | |||
| return {}; | |||
| } | |||
| @@ -333,9 +332,7 @@ public: | |||
| ExecutionPolicy& execution_policy() { return m_execution_policy; } | |||
| const ExecutionPolicy& execution_policy() const { | |||
| return m_execution_policy; | |||
| } | |||
| const ExecutionPolicy& execution_policy() const { return m_execution_policy; } | |||
| virtual Algorithm* get_algorithm_from_desc(const AlgorithmDesc&) = 0; | |||
| @@ -355,8 +352,8 @@ public: | |||
| using AlgoAttribute = detail::Algorithm::Attribute; | |||
| //! get all possible algorithm descriptions for the specified layouts | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||
| const TensorLayout& p1) { | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info( | |||
| const TensorLayout& p0, const TensorLayout& p1) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms(p0, p1)) { | |||
| ret.emplace_back(algo->info()); | |||
| @@ -364,8 +361,8 @@ public: | |||
| return ret; | |||
| } | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||
| const TensorLayout& p1) { | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||
| const TensorLayout& p0, const TensorLayout& p1) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms_safe(p0, p1)) { | |||
| ret.emplace_back(algo->info()); | |||
| @@ -382,12 +379,11 @@ public: | |||
| */ | |||
| AlgorithmInfo get_algorithm_info_heuristic( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||
| return get_algorithm_heuristic(p0, p1, workspace_limit_in_bytes, | |||
| positive_attr, negative_attr) | |||
| return get_algorithm_heuristic( | |||
| p0, p1, workspace_limit_in_bytes, positive_attr, negative_attr) | |||
| ->info(); | |||
| } | |||
| @@ -408,8 +404,7 @@ protected: | |||
| */ | |||
| virtual Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | |||
| }; | |||
| @@ -423,9 +418,8 @@ public: | |||
| using AlgoAttribute = detail::Algorithm::Attribute; | |||
| //! get all possible algorithm descriptions for the specified layouts | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2) { | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info( | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms(p0, p1, p2)) { | |||
| ret.emplace_back(algo->info()); | |||
| @@ -433,9 +427,8 @@ public: | |||
| return ret; | |||
| } | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2) { | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms_safe(p0, p1, p2)) { | |||
| ret.emplace_back(algo->info()); | |||
| @@ -451,14 +444,13 @@ public: | |||
| * \p workspace_limit_in_bytes. | |||
| */ | |||
| AlgorithmInfo get_algorithm_info_heuristic( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||
| return get_algorithm_heuristic(p0, p1, p2, workspace_limit_in_bytes, | |||
| positive_attr, negative_attr) | |||
| return get_algorithm_heuristic( | |||
| p0, p1, p2, workspace_limit_in_bytes, positive_attr, | |||
| negative_attr) | |||
| ->info(); | |||
| } | |||
| @@ -467,11 +459,9 @@ protected: | |||
| //! get all possible algorithms for the specified layouts | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2) = 0; | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) = 0; | |||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2) = 0; | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2) = 0; | |||
| /** | |||
| * \brief Returns the best algorithm by heuristic. | |||
| @@ -480,10 +470,8 @@ protected: | |||
| * \p workspace_limit_in_bytes. | |||
| */ | |||
| virtual Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | |||
| }; | |||
| @@ -497,10 +485,9 @@ public: | |||
| using AlgoAttribute = detail::Algorithm::Attribute; | |||
| //! get all possible algorithm descriptions for the specified layouts | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2, | |||
| const TensorLayout& p3) { | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info( | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms(p0, p1, p2, p3)) { | |||
| ret.emplace_back(algo->info()); | |||
| @@ -508,10 +495,9 @@ public: | |||
| return ret; | |||
| } | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2, | |||
| const TensorLayout& p3) { | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3)) { | |||
| ret.emplace_back(algo->info()); | |||
| @@ -527,14 +513,14 @@ public: | |||
| * \p workspace_limit_in_bytes. | |||
| */ | |||
| AlgorithmInfo get_algorithm_info_heuristic( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3, | |||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||
| return get_algorithm_heuristic(p0, p1, p2, p3, workspace_limit_in_bytes, | |||
| positive_attr, negative_attr) | |||
| return get_algorithm_heuristic( | |||
| p0, p1, p2, p3, workspace_limit_in_bytes, positive_attr, | |||
| negative_attr) | |||
| ->info(); | |||
| } | |||
| @@ -543,11 +529,11 @@ protected: | |||
| //! get all possible algorithms for the specified layouts | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3) = 0; | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3) = 0; | |||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3) = 0; | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3) = 0; | |||
| /** | |||
| * \brief Returns the best algorithm by heuristic. | |||
| @@ -556,10 +542,9 @@ protected: | |||
| * \p workspace_limit_in_bytes. | |||
| */ | |||
| virtual Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3, | |||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | |||
| }; | |||
| @@ -573,11 +558,9 @@ public: | |||
| using AlgoAttribute = detail::Algorithm::Attribute; | |||
| //! get all possible algorithm descriptions for the specified layouts | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2, | |||
| const TensorLayout& p3, | |||
| const TensorLayout& p4) { | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info( | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3, const TensorLayout& p4) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4)) { | |||
| ret.emplace_back(algo->info()); | |||
| @@ -585,11 +568,9 @@ public: | |||
| return ret; | |||
| } | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe(const TensorLayout& p0, | |||
| const TensorLayout& p1, | |||
| const TensorLayout& p2, | |||
| const TensorLayout& p3, | |||
| const TensorLayout& p4) { | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3, const TensorLayout& p4) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4)) { | |||
| ret.emplace_back(algo->info()); | |||
| @@ -605,16 +586,14 @@ public: | |||
| * \p workspace_limit_in_bytes. | |||
| */ | |||
| AlgorithmInfo get_algorithm_info_heuristic( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4, | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3, const TensorLayout& p4, | |||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||
| return get_algorithm_heuristic(p0, p1, p2, p3, p4, | |||
| workspace_limit_in_bytes, positive_attr, | |||
| negative_attr) | |||
| return get_algorithm_heuristic( | |||
| p0, p1, p2, p3, p4, workspace_limit_in_bytes, positive_attr, | |||
| negative_attr) | |||
| ->info(); | |||
| } | |||
| @@ -622,14 +601,12 @@ protected: | |||
| ~MultiAlgoOpr() = default; | |||
| //! get all possible algorithms for the specified layouts | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4) = 0; | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3, const TensorLayout& p4) = 0; | |||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4) = 0; | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3, const TensorLayout& p4) = 0; | |||
| /** | |||
| * \brief Returns the best algorithm by heuristic. | |||
| @@ -638,11 +615,9 @@ protected: | |||
| * \p workspace_limit_in_bytes. | |||
| */ | |||
| virtual Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4, | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3, const TensorLayout& p4, | |||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | |||
| }; | |||
| @@ -657,9 +632,8 @@ public: | |||
| //! get all possible algorithm descriptions for the specified layouts | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4, const TensorLayout& p5, | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||
| const TensorLayout& p6, const TensorLayout& p7) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms(p0, p1, p2, p3, p4, p5, p6, p7)) { | |||
| @@ -669,9 +643,8 @@ public: | |||
| } | |||
| std::vector<AlgorithmInfo> get_all_algorithms_info_safe( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4, const TensorLayout& p5, | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||
| const TensorLayout& p6, const TensorLayout& p7) { | |||
| std::vector<AlgorithmInfo> ret; | |||
| for (auto&& algo : get_all_algorithms_safe(p0, p1, p2, p3, p4, p5, p6, p7)) { | |||
| @@ -687,17 +660,15 @@ public: | |||
| * The selected algorithm should not use workspace more than \p workspace_limit_in_bytes. | |||
| */ | |||
| AlgorithmInfo get_algorithm_info_heuristic( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4, const TensorLayout& p5, | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||
| const TensorLayout& p6, const TensorLayout& p7, | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) { | |||
| return get_algorithm_heuristic(p0, p1, p2, p3, p4, p5, p6, p7, | |||
| workspace_limit_in_bytes, positive_attr, | |||
| negative_attr) | |||
| return get_algorithm_heuristic( | |||
| p0, p1, p2, p3, p4, p5, p6, p7, workspace_limit_in_bytes, | |||
| positive_attr, negative_attr) | |||
| ->info(); | |||
| } | |||
| @@ -705,15 +676,13 @@ protected: | |||
| ~MultiAlgoOpr() = default; | |||
| //! get all possible algorithms for the specified layouts | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4, const TensorLayout& p5, | |||
| virtual std::vector<Algorithm*> get_all_algorithms( | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||
| const TensorLayout& p6, const TensorLayout& p7) = 0; | |||
| virtual std::vector<Algorithm*> get_all_algorithms_safe( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4, const TensorLayout& p5, | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||
| const TensorLayout& p6, const TensorLayout& p7) = 0; | |||
| /** | |||
| @@ -723,12 +692,10 @@ protected: | |||
| * \p workspace_limit_in_bytes. | |||
| */ | |||
| virtual Algorithm* get_algorithm_heuristic( | |||
| const TensorLayout& p0, const TensorLayout& p1, | |||
| const TensorLayout& p2, const TensorLayout& p3, | |||
| const TensorLayout& p4, const TensorLayout& p5, | |||
| const TensorLayout& p0, const TensorLayout& p1, const TensorLayout& p2, | |||
| const TensorLayout& p3, const TensorLayout& p4, const TensorLayout& p5, | |||
| const TensorLayout& p6, const TensorLayout& p7, | |||
| size_t workspace_limit_in_bytes = | |||
| std::numeric_limits<size_t>::max(), | |||
| size_t workspace_limit_in_bytes = std::numeric_limits<size_t>::max(), | |||
| const AlgoAttribute& positive_attr = AlgoAttribute::DEFAULT, | |||
| const AlgoAttribute& negative_attr = AlgoAttribute::DEFAULT) = 0; | |||
| }; | |||
| @@ -31,15 +31,17 @@ class FlipForward : public FlipBase { | |||
| DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst); | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& dst) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| using Flip = FlipForward; | |||
| @@ -56,15 +58,17 @@ class RotateForward : public RotateBase { | |||
| DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst); | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& dst) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| using Rotate = RotateForward; | |||
| @@ -81,15 +85,17 @@ class ROICopyForward : public ROICopyBase { | |||
| DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst); | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& dst) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| using ROICopy = ROICopyForward; | |||
| @@ -106,15 +112,17 @@ class CvtColorForward : public CvtColorBase { | |||
| DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst); | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& dst) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| using CvtColor = CvtColorForward; | |||
| @@ -130,8 +138,9 @@ public: | |||
| using BorderMode = Param::BorderMode; | |||
| protected: | |||
| void check_layout_fwd(const TensorLayout& src, const TensorLayout& trans, | |||
| const TensorLayout& dst); | |||
| void check_layout_fwd( | |||
| const TensorLayout& src, const TensorLayout& trans, | |||
| const TensorLayout& dst); | |||
| std::string param_msg() const; | |||
| int get_real_coord(int p, int len); | |||
| }; | |||
| @@ -148,15 +157,17 @@ public: | |||
| * \warning src, trans, border_value, dst should be contiguous | |||
| * The size of trans is N * 2 * 3 | |||
| */ | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, | |||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& trans, | |||
| const TensorLayout& dst) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_in trans, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& trans, | |||
| const TensorLayout& dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& src, const TensorLayout& trans, | |||
| const TensorLayout& dst, size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& src, const TensorLayout& trans, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| using WarpAffine = WarpAffineForward; | |||
| @@ -173,15 +184,17 @@ class GaussianBlurForward : public GaussianBlurBase { | |||
| DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst); | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& dst) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| using GaussianBlur = GaussianBlurForward; | |||
| @@ -212,15 +225,17 @@ class ResizeForward : public ResizeBase { | |||
| DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& dst) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| using Resize = ResizeForward; | |||
| @@ -228,15 +243,17 @@ class ResizeBackward : public ResizeBase { | |||
| DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||
| const TensorLayout& mat) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& diff, const TensorLayout& mat) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& diff, const TensorLayout& mat, | |||
| size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& diff, const TensorLayout& mat, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| /** | |||
| @@ -251,29 +268,32 @@ public: | |||
| using BorderMode = Param::BorderMode; | |||
| protected: | |||
| void check_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy, | |||
| const TensorLayout& dst); | |||
| void deduce_layout_fwd(const TensorLayout& src, const TensorLayout& map_xy, | |||
| TensorLayout& dst); | |||
| void check_layout_fwd( | |||
| const TensorLayout& src, const TensorLayout& map_xy, | |||
| const TensorLayout& dst); | |||
| void deduce_layout_fwd( | |||
| const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst); | |||
| }; | |||
| class RemapForward : public RemapBase { | |||
| DEF_OPR_IMPL(RemapForward, RemapBase, 2, 1); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout(const TensorLayout& src, const TensorLayout& map_xy, | |||
| TensorLayout& dst); | |||
| void deduce_layout( | |||
| const TensorLayout& src, const TensorLayout& map_xy, TensorLayout& dst); | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& map_xy, | |||
| const TensorLayout& dst) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& map_xy, | |||
| const TensorLayout& dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& src, const TensorLayout& map_xy, | |||
| const TensorLayout& dst, size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& src, const TensorLayout& map_xy, | |||
| const TensorLayout& dst, size_t workspace_in_bytes); | |||
| }; | |||
| using Remap = RemapForward; | |||
| @@ -281,35 +301,37 @@ class RemapBackwardData : public RemapBase { | |||
| DEF_OPR_IMPL(RemapBackwardData, RemapBase, 2, 1); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in map_xy, _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in map_xy, _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& map_xy, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& map_xy, const TensorLayout& diff, | |||
| const TensorLayout& grad) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& map_xy, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& map_xy, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_in_bytes); | |||
| }; | |||
| class RemapBackwardMat : public RemapBase { | |||
| DEF_OPR_IMPL(RemapBackwardMat, RemapBase, 3, 1); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in map_xy, | |||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_in map_xy, _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& map_xy, | |||
| const TensorLayout& diff, | |||
| const TensorLayout& grad) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& map_xy, | |||
| const TensorLayout& diff, const TensorLayout& grad) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& src, const TensorLayout& map_xy, | |||
| const TensorLayout& diff, const TensorLayout& grad, | |||
| size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& src, const TensorLayout& map_xy, | |||
| const TensorLayout& diff, const TensorLayout& grad, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| class SeparableFilterBase : public OperatorBase { | |||
| @@ -317,32 +339,34 @@ class SeparableFilterBase : public OperatorBase { | |||
| DEF_OPR_PARAM(SeparableFilter); | |||
| protected: | |||
| void deduce_layout_fwd(const TensorLayout& src, | |||
| const TensorLayout& filter_x, | |||
| const TensorLayout& filter_y, TensorLayout& dst); | |||
| void check_layout_fwd(const TensorLayout& src, const TensorLayout& filter_x, | |||
| const TensorLayout& filter_y, | |||
| const TensorLayout& dst); | |||
| void deduce_layout_fwd( | |||
| const TensorLayout& src, const TensorLayout& filter_x, | |||
| const TensorLayout& filter_y, TensorLayout& dst); | |||
| void check_layout_fwd( | |||
| const TensorLayout& src, const TensorLayout& filter_x, | |||
| const TensorLayout& filter_y, const TensorLayout& dst); | |||
| }; | |||
| class SeparableFilterForward : public SeparableFilterBase { | |||
| DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter_x, | |||
| _megdnn_tensor_in filter_y, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout(const TensorLayout& src, const TensorLayout& filter_x, | |||
| const TensorLayout& filter_y, TensorLayout& dst); | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& filter_x, | |||
| const TensorLayout& filter_y, | |||
| const TensorLayout& dst) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_in filter_x, | |||
| _megdnn_tensor_in filter_y, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout( | |||
| const TensorLayout& src, const TensorLayout& filter_x, | |||
| const TensorLayout& filter_y, TensorLayout& dst); | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& filter_x, | |||
| const TensorLayout& filter_y, const TensorLayout& dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& src, const TensorLayout& filter_x, | |||
| const TensorLayout& filter_y, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& src, const TensorLayout& filter_x, | |||
| const TensorLayout& filter_y, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| using SeparableFilter = SeparableFilterForward; | |||
| @@ -13,173 +13,162 @@ | |||
| namespace megdnn { | |||
| class WarpPerspectiveBase: public OperatorBase { | |||
| class WarpPerspectiveBase : public OperatorBase { | |||
| DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase); | |||
| DEF_OPR_PARAM(WarpPerspective); | |||
| public: | |||
| using InterpolationMode = Param::InterpolationMode; | |||
| using BorderMode = Param::BorderMode; | |||
| protected: | |||
| void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat, | |||
| const TensorLayout &dst) { | |||
| check_layout_fwd(src, mat, {}, dst); | |||
| } | |||
| void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat, | |||
| const TensorLayout &mat_idx, const TensorLayout &dst); | |||
| std::string param_msg() const; | |||
| int get_real_coord(int p, int len); | |||
| public: | |||
| using InterpolationMode = Param::InterpolationMode; | |||
| using BorderMode = Param::BorderMode; | |||
| protected: | |||
| void check_layout_fwd( | |||
| const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { | |||
| check_layout_fwd(src, mat, {}, dst); | |||
| } | |||
| void check_layout_fwd( | |||
| const TensorLayout& src, const TensorLayout& mat, | |||
| const TensorLayout& mat_idx, const TensorLayout& dst); | |||
| std::string param_msg() const; | |||
| int get_real_coord(int p, int len); | |||
| }; | |||
| class WarpPerspectiveForward: public WarpPerspectiveBase { | |||
| class WarpPerspectiveForward : public WarpPerspectiveBase { | |||
| DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1); | |||
| public: | |||
| /** | |||
| * \param[in] src (n, channel, in_height, in_width) | |||
| * \param[in] mat (n, 3, 3) | |||
| * \param[out] dst (n, channel, out_height, out_width) | |||
| * | |||
| * \see http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine | |||
| * | |||
| * denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2] | |||
| * dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator, | |||
| * (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator) | |||
| * | |||
| * src and dst can have different shapes, as long as their n and c agree. | |||
| * src, mat and dst should be contiguous. | |||
| */ | |||
| void exec(_megdnn_tensor_in src, | |||
| _megdnn_tensor_in mat, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) { | |||
| exec(src, mat, {}, dst, workspace); | |||
| } | |||
| /** | |||
| * \p src should have batch size m, and \p mat and \p mat_idx should | |||
| * both have batch size n. Each item in \p mat_idx must be in the range | |||
| * of [0, m-1]. | |||
| * | |||
| * \param mat_idx the indices of input image that each matrix in \p mat | |||
| * should act on. It can also be empty and in such case \p mat | |||
| * should have the same batch size as \p src. | |||
| */ | |||
| virtual void exec(_megdnn_tensor_in src, | |||
| _megdnn_tensor_in mat, | |||
| _megdnn_tensor_in mat_idx, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| size_t get_workspace_in_bytes(const TensorLayout &src, | |||
| const TensorLayout &mat, | |||
| const TensorLayout &dst) { | |||
| return get_workspace_in_bytes(src, mat, {}, dst); | |||
| } | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||
| const TensorLayout &mat, | |||
| const TensorLayout &mat_idx, | |||
| const TensorLayout &dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout &src, | |||
| const TensorLayout &mat, | |||
| const TensorLayout &mat_idx, | |||
| const TensorLayout &dst, | |||
| size_t workspace_in_bytes); | |||
| void check_exec_allow_nhwc_mat_idx(const TensorLayout &src, | |||
| const TensorLayout &mat, | |||
| const TensorLayout &mat_idx, | |||
| const TensorLayout &dst, | |||
| size_t workspace_in_bytes); | |||
| public: | |||
| /** | |||
| * \param[in] src (n, channel, in_height, in_width) | |||
| * \param[in] mat (n, 3, 3) | |||
| * \param[out] dst (n, channel, out_height, out_width) | |||
| * | |||
| * \see | |||
| * http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine | |||
| * | |||
| * denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2] | |||
| * dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator, | |||
| * (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator) | |||
| * | |||
| * src and dst can have different shapes, as long as their n and c agree. | |||
| * src, mat and dst should be contiguous. | |||
| */ | |||
| void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) { | |||
| exec(src, mat, {}, dst, workspace); | |||
| } | |||
| /** | |||
| * \p src should have batch size m, and \p mat and \p mat_idx should | |||
| * both have batch size n. Each item in \p mat_idx must be in the range | |||
| * of [0, m-1]. | |||
| * | |||
| * \param mat_idx the indices of input image that each matrix in \p mat | |||
| * should act on. It can also be empty and in such case \p mat | |||
| * should have the same batch size as \p src. | |||
| */ | |||
| virtual void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | |||
| _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||
| size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& mat, const TensorLayout& dst) { | |||
| return get_workspace_in_bytes(src, mat, {}, dst); | |||
| } | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& mat, | |||
| const TensorLayout& mat_idx, const TensorLayout& dst) = 0; | |||
| protected: | |||
| void check_exec( | |||
| const TensorLayout& src, const TensorLayout& mat, | |||
| const TensorLayout& mat_idx, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| void check_exec_allow_nhwc_mat_idx( | |||
| const TensorLayout& src, const TensorLayout& mat, | |||
| const TensorLayout& mat_idx, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| using WarpPerspective = WarpPerspectiveForward; | |||
| class WarpPerspectiveBackwardData: public WarpPerspectiveBase { | |||
| class WarpPerspectiveBackwardData : public WarpPerspectiveBase { | |||
| DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1); | |||
| public: | |||
| /** | |||
| * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||
| * \param[in] diff the backpropagated gradient wrt. dst | |||
| * \param[out] grad the backpropagated gradient wrt. src | |||
| * \param[out] workspace temporary workspace to perform backward | |||
| */ | |||
| void exec(_megdnn_tensor_in mat, | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) { | |||
| exec(mat, {}, diff, grad, workspace); | |||
| } | |||
| virtual void exec(_megdnn_tensor_in mat, | |||
| _megdnn_tensor_in mat_idx, | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) = 0; | |||
| size_t get_workspace_in_bytes(const TensorLayout &mat, | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad) { | |||
| return get_workspace_in_bytes(mat, {}, diff, grad); | |||
| } | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout &mat, | |||
| const TensorLayout &mat_idx, | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout &mat, | |||
| const TensorLayout &mat_idx, | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad, | |||
| size_t workspace_in_bytes); | |||
| public: | |||
| /** | |||
| * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||
| * \param[in] diff the backpropagated gradient wrt. dst | |||
| * \param[out] grad the backpropagated gradient wrt. src | |||
| * \param[out] workspace temporary workspace to perform backward | |||
| */ | |||
| void exec( | |||
| _megdnn_tensor_in mat, _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) { | |||
| exec(mat, {}, diff, grad, workspace); | |||
| } | |||
| virtual void exec( | |||
| _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||
| size_t get_workspace_in_bytes( | |||
| const TensorLayout& mat, const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| return get_workspace_in_bytes(mat, {}, diff, grad); | |||
| } | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& mat, const TensorLayout& mat_idx, | |||
| const TensorLayout& diff, const TensorLayout& grad) = 0; | |||
| protected: | |||
| void check_exec( | |||
| const TensorLayout& mat, const TensorLayout& mat_idx, | |||
| const TensorLayout& diff, const TensorLayout& grad, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| class WarpPerspectiveBackwardMat: public WarpPerspectiveBase { | |||
| class WarpPerspectiveBackwardMat : public WarpPerspectiveBase { | |||
| DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1); | |||
| public: | |||
| /** | |||
| * \param[in] src the `src' parameter in WarpPerspectiveForward::exec | |||
| * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||
| * \param[in] diff the backpropagated gradient wrt. dst | |||
| * \param[out] grad the backpropagated gradient wrt. mat | |||
| * \param[out] workspace temporary workspace to perform backward | |||
| */ | |||
| void exec(_megdnn_tensor_in src, | |||
| _megdnn_tensor_in mat, | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) { | |||
| exec(src, mat, {}, diff, grad, workspace); | |||
| } | |||
| virtual void exec(_megdnn_tensor_in src, | |||
| _megdnn_tensor_in mat, | |||
| _megdnn_tensor_in mat_idx, | |||
| _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) = 0; | |||
| size_t get_workspace_in_bytes(const TensorLayout &src, | |||
| const TensorLayout &mat, | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad) { | |||
| return get_workspace_in_bytes(src, mat, {}, diff, grad); | |||
| } | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||
| const TensorLayout &mat, | |||
| const TensorLayout &mat_idx, | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout &src, | |||
| const TensorLayout &mat, | |||
| const TensorLayout &mat_idx, | |||
| const TensorLayout &diff, | |||
| const TensorLayout &grad, | |||
| size_t workspace_in_bytes); | |||
| public: | |||
| /** | |||
| * \param[in] src the `src' parameter in WarpPerspectiveForward::exec | |||
| * \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||
| * \param[in] diff the backpropagated gradient wrt. dst | |||
| * \param[out] grad the backpropagated gradient wrt. mat | |||
| * \param[out] workspace temporary workspace to perform backward | |||
| */ | |||
| void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in diff, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) { | |||
| exec(src, mat, {}, diff, grad, workspace); | |||
| } | |||
| virtual void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_in mat, _megdnn_tensor_in mat_idx, | |||
| _megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) = 0; | |||
| size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& mat, const TensorLayout& diff, | |||
| const TensorLayout& grad) { | |||
| return get_workspace_in_bytes(src, mat, {}, diff, grad); | |||
| } | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& mat, | |||
| const TensorLayout& mat_idx, const TensorLayout& diff, | |||
| const TensorLayout& grad) = 0; | |||
| protected: | |||
| void check_exec( | |||
| const TensorLayout& src, const TensorLayout& mat, | |||
| const TensorLayout& mat_idx, const TensorLayout& diff, | |||
| const TensorLayout& grad, size_t workspace_in_bytes); | |||
| }; | |||
| class DctChannelSelectForward : public OperatorBase { | |||
| @@ -194,37 +183,32 @@ public: | |||
| * \param[dst] DctChannelSelectForward output, default fp32 nchw tensor | |||
| * \param[out] workspace temporary workspace to perform forward | |||
| */ | |||
| virtual void exec(_megdnn_tensor_in src, | |||
| _megdnn_tensor_in mask_offset, | |||
| _megdnn_tensor_in mask_val, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout(const TensorLayout& src, | |||
| const TensorLayout& mask_offset, | |||
| const TensorLayout& mask_val, | |||
| TensorLayout& dst); | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& mask_offset, | |||
| const TensorLayout& mask_val, | |||
| const TensorLayout& dst) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_in mask_offset, | |||
| _megdnn_tensor_in mask_val, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout( | |||
| const TensorLayout& src, const TensorLayout& mask_offset, | |||
| const TensorLayout& mask_val, TensorLayout& dst); | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& mask_offset, | |||
| const TensorLayout& mask_val, const TensorLayout& dst) = 0; | |||
| protected: | |||
| void check_layout_fwd(const TensorLayout& src, | |||
| const TensorLayout& mask_offset, | |||
| const TensorLayout& mask_val, | |||
| const TensorLayout& dst); | |||
| void deduce_layout_fwd(const TensorLayout& src, | |||
| const TensorLayout& mask_offset, | |||
| const TensorLayout& mask_val, | |||
| TensorLayout& dst); | |||
| void check_layout_fwd( | |||
| const TensorLayout& src, const TensorLayout& mask_offset, | |||
| const TensorLayout& mask_val, const TensorLayout& dst); | |||
| void deduce_layout_fwd( | |||
| const TensorLayout& src, const TensorLayout& mask_offset, | |||
| const TensorLayout& mask_val, TensorLayout& dst); | |||
| std::string param_msg() const; | |||
| }; | |||
| } // namespace megdnn | |||
| } // namespace megdnn | |||
| #include "megdnn/internal/opr_header_epilogue.h" | |||
| @@ -33,22 +33,22 @@ public: | |||
| * op(A) = A if transposeA is false, otherwise op(A) = A^t. | |||
| * op(B) = B if transposeB is false, otherwise op(B) = B^t. | |||
| */ | |||
| virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
| _megdnn_tensor_out C, _megdnn_workspace workspace) = 0; | |||
| void deduce_dtype(DType A, DType B, DType &C); | |||
| void deduce_layout(const TensorLayout& A, const TensorLayout& B, | |||
| TensorLayout& C); | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_dtype(DType A, DType B, DType& C); | |||
| void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | |||
| static Algorithm::OprType get_opr_type() { | |||
| return Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD; | |||
| } | |||
| protected: | |||
| void check_exec(const TensorLayout& A, const TensorLayout& B, | |||
| const TensorLayout& C, size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| using BatchedMatrixMul = BatchedMatrixMulForward; | |||
| @@ -70,24 +70,24 @@ public: | |||
| * op(A) = A if transposeA is false, otherwise op(A) = A^t. | |||
| * op(B) = B if transposeB is false, otherwise op(B) = B^t. | |||
| */ | |||
| virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
| _megdnn_tensor_out C, _megdnn_workspace workspace) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_dtype(DType A, DType B, DType& C); | |||
| void deduce_layout(const TensorLayout& A, const TensorLayout& B, | |||
| TensorLayout& C); | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) = 0; | |||
| void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | |||
| static size_t pack_size (const Param::Format format); | |||
| static size_t pack_size(const Param::Format format); | |||
| static Algorithm::OprType get_opr_type() { | |||
| return Algorithm::OprType::MATRIX_MUL_FORWARD; | |||
| } | |||
| protected: | |||
| void check_exec(const TensorLayout& A, const TensorLayout& B, | |||
| const TensorLayout& C, size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| using MatrixMul = MatrixMulForward; | |||
| @@ -104,11 +104,11 @@ class MatrixInverse : public OperatorBase { | |||
| DEF_OPR_PARAM(Empty); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst); | |||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& dst); | |||
| size_t get_workspace_in_bytes(const TensorLayout& src, const TensorLayout& dst); | |||
| protected: | |||
| /*! | |||
| @@ -116,8 +116,7 @@ protected: | |||
| * | |||
| * Note that \p batch and \p n can be null | |||
| */ | |||
| static void canonize_params(const TensorLayout& layout, size_t* batch, | |||
| size_t* n); | |||
| static void canonize_params(const TensorLayout& layout, size_t* batch, size_t* n); | |||
| /*! | |||
| * \brief canonize and validate input params for exec() impls | |||
| @@ -125,11 +124,12 @@ protected: | |||
| * Since get_workspace_in_bytes() would be called, \p batch and \p n can not | |||
| * be null | |||
| */ | |||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
| _megdnn_workspace workspace, size_t* batch, size_t* n); | |||
| void check_exec( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| _megdnn_workspace workspace, size_t* batch, size_t* n); | |||
| virtual size_t get_workspace_in_bytes(size_t batch, size_t n, | |||
| size_t dtype_size) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| size_t batch, size_t n, size_t dtype_size) = 0; | |||
| }; | |||
| //! inter-product of two vectors | |||
| @@ -147,17 +147,17 @@ public: | |||
| * A, B, C must be contiguous. A and B must have the same 1-dimensional | |||
| * shape and non-negative strides. C must be scalar. | |||
| */ | |||
| virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
| _megdnn_tensor_out C, _megdnn_workspace workspace) = 0; | |||
| void deduce_layout(const TensorLayout& A, const TensorLayout& B, | |||
| TensorLayout& C); | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& A, | |||
| const TensorLayout& B, | |||
| const TensorLayout& C) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout(const TensorLayout& A, const TensorLayout& B, TensorLayout& C); | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& A, const TensorLayout& B, | |||
| const TensorLayout& C, size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& A, const TensorLayout& B, const TensorLayout& C, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| using Dot = DotForward; | |||
| @@ -193,23 +193,24 @@ public: | |||
| * if compute_uv is false (default to true). | |||
| * | |||
| */ | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out u, | |||
| _megdnn_tensor_out s, _megdnn_tensor_out vt, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout(const TensorLayout& src, TensorLayout& u, | |||
| TensorLayout& s, TensorLayout& vt); | |||
| size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& u, const TensorLayout& s, | |||
| const TensorLayout& vt); | |||
| virtual void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_out u, _megdnn_tensor_out s, | |||
| _megdnn_tensor_out vt, _megdnn_workspace workspace) = 0; | |||
| void deduce_layout( | |||
| const TensorLayout& src, TensorLayout& u, TensorLayout& s, | |||
| TensorLayout& vt); | |||
| size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& u, const TensorLayout& s, | |||
| const TensorLayout& vt); | |||
| protected: | |||
| static void canonize_params(const TensorLayout& layout, size_t* batch, | |||
| size_t* m, size_t* n); | |||
| virtual size_t get_workspace_in_bytes(size_t block_cnt, size_t m, size_t n, | |||
| size_t dtype_size) = 0; | |||
| void check_exec(const TensorLayout& src, const TensorLayout& u, | |||
| const TensorLayout& s, const TensorLayout& vt, | |||
| size_t workspace_in_bytes); | |||
| static void canonize_params( | |||
| const TensorLayout& layout, size_t* batch, size_t* m, size_t* n); | |||
| virtual size_t get_workspace_in_bytes( | |||
| size_t block_cnt, size_t m, size_t n, size_t dtype_size) = 0; | |||
| void check_exec( | |||
| const TensorLayout& src, const TensorLayout& u, const TensorLayout& s, | |||
| const TensorLayout& vt, size_t workspace_in_bytes); | |||
| }; | |||
| using SVD = SVDForward; | |||
| @@ -36,7 +36,7 @@ public: | |||
| struct ModeTrait { | |||
| uint32_t arity = 0; //!< number of inputs needed | |||
| CheckDtypeFunc check_inp[MAX_ARITY]; | |||
| SetOrCheckDtypeFunc check_out; //!< dtype of output var | |||
| SetOrCheckDtypeFunc check_out; //!< dtype of output var | |||
| bool need_specify_out_dtype = | |||
| false; //!< the dtype should be setup externally, otherwise | |||
| //!< would be inferred by check_out(dtype, false) | |||
| @@ -46,13 +46,10 @@ public: | |||
| static const ModeTrait& from_mode(Mode mode); | |||
| }; | |||
| virtual void exec(_megdnn_in const TensorNDArray& src, | |||
| _megdnn_tensor_out dst) = 0; | |||
| virtual void exec(_megdnn_in const TensorNDArray& src, _megdnn_tensor_out dst) = 0; | |||
| //! get trait of current mode | |||
| const ModeTrait& mode_trait() const { | |||
| return ModeTrait::from_mode(m_param.mode); | |||
| } | |||
| const ModeTrait& mode_trait() const { return ModeTrait::from_mode(m_param.mode); } | |||
| //! deduce output layout | |||
| void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst); | |||
| @@ -60,8 +57,8 @@ public: | |||
| protected: | |||
| //! throw exception if incorrect layout; broadcast input shape to | |||
| //! output shape | |||
| void check_layout_and_broadcast(const TensorLayoutPtrArray& src, | |||
| const TensorLayout& dst); | |||
| void check_layout_and_broadcast( | |||
| const TensorLayoutPtrArray& src, const TensorLayout& dst); | |||
| }; | |||
| } // namespace megdnn | |||
| @@ -15,84 +15,97 @@ | |||
| namespace megdnn { | |||
| //! base class for random number generators | |||
| class RNGBase: public OperatorBase { | |||
| class RNGBase : public OperatorBase { | |||
| DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0; | |||
| protected: | |||
| virtual void check_exec(const TensorLayout &dst, size_t workspace_in_bytes) = 0; | |||
| public: | |||
| virtual void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& dst) = 0; | |||
| protected: | |||
| virtual void check_exec(const TensorLayout& dst, size_t workspace_in_bytes) = 0; | |||
| }; | |||
| //! sample from poisson distribution | |||
| class PoissonRNG: public OperatorBase { | |||
| class PoissonRNG : public OperatorBase { | |||
| DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1); | |||
| DEF_OPR_PARAM(PoissonRNG); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in lam, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout &lam, | |||
| const TensorLayout &dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout &lam, const TensorLayout &dst, | |||
| size_t workspace_in_bytes); | |||
| public: | |||
| virtual void exec( | |||
| _megdnn_tensor_in lam, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& lam, const TensorLayout& dst) = 0; | |||
| protected: | |||
| void check_exec( | |||
| const TensorLayout& lam, const TensorLayout& dst, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| //! sample from beta distribution | |||
| class BetaRNG: public OperatorBase { | |||
| class BetaRNG : public OperatorBase { | |||
| DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1); | |||
| DEF_OPR_PARAM(BetaRNG); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in alpha, | |||
| _megdnn_tensor_in beta, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout &alpha, | |||
| const TensorLayout &beta, const TensorLayout &dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout &alpha, const TensorLayout &beta, | |||
| const TensorLayout &dst, size_t workspace_in_bytes); | |||
| public: | |||
| virtual void exec( | |||
| _megdnn_tensor_in alpha, _megdnn_tensor_in beta, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& alpha, const TensorLayout& beta, | |||
| const TensorLayout& dst) = 0; | |||
| protected: | |||
| void check_exec( | |||
| const TensorLayout& alpha, const TensorLayout& beta, | |||
| const TensorLayout& dst, size_t workspace_in_bytes); | |||
| }; | |||
| //! sample from gamma distribution | |||
| class GammaRNG: public OperatorBase { | |||
| class GammaRNG : public OperatorBase { | |||
| DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1); | |||
| DEF_OPR_PARAM(GammaRNG); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in shape, | |||
| _megdnn_tensor_in scale, | |||
| _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout &shape, | |||
| const TensorLayout &scale, const TensorLayout &dst) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout &shape, const TensorLayout &scale, | |||
| const TensorLayout &dst, size_t workspace_in_bytes); | |||
| public: | |||
| virtual void exec( | |||
| _megdnn_tensor_in shape, _megdnn_tensor_in scale, _megdnn_tensor_out dst, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& shape, const TensorLayout& scale, | |||
| const TensorLayout& dst) = 0; | |||
| protected: | |||
| void check_exec( | |||
| const TensorLayout& shape, const TensorLayout& scale, | |||
| const TensorLayout& dst, size_t workspace_in_bytes); | |||
| }; | |||
| //! sample from uniform distribution on the interval (0, 1] | |||
| class UniformRNG: public RNGBase { | |||
| class UniformRNG : public RNGBase { | |||
| DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1); | |||
| DEF_OPR_PARAM(UniformRNG); | |||
| protected: | |||
| void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||
| protected: | |||
| void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); | |||
| }; | |||
| //! sample from gaussian distribution | |||
| class GaussianRNG: public RNGBase { | |||
| class GaussianRNG : public RNGBase { | |||
| DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1); | |||
| DEF_OPR_PARAM(GaussianRNG); | |||
| protected: | |||
| void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||
| protected: | |||
| void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); | |||
| }; | |||
| class PermutationRNG: public RNGBase { | |||
| class PermutationRNG : public RNGBase { | |||
| DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1); | |||
| DEF_OPR_PARAM(PermutationRNG); | |||
| protected: | |||
| void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||
| protected: | |||
| void check_exec(const TensorLayout& dst, size_t workspace_in_bytes); | |||
| }; | |||
| class ShuffleRNGForward : public OperatorBase { | |||
| @@ -100,18 +113,19 @@ class ShuffleRNGForward : public OperatorBase { | |||
| DEF_OPR_PARAM(ShuffleRNG); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
| _megdnn_tensor_out indices, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout(const TensorLayout& src, TensorLayout& dst, | |||
| TensorLayout& indices); | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
| const TensorLayout& dst, | |||
| const TensorLayout& indices) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_tensor_out indices, | |||
| _megdnn_workspace workspace) = 0; | |||
| void deduce_layout( | |||
| const TensorLayout& src, TensorLayout& dst, TensorLayout& indices); | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| const TensorLayout& indices) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
| const TensorLayout& indices, size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& src, const TensorLayout& dst, | |||
| const TensorLayout& indices, size_t workspace_in_bytes); | |||
| }; | |||
| using ShuffleRNG = ShuffleRNGForward; | |||
| @@ -120,27 +134,29 @@ class ShuffleRNGBackward : public OperatorBase { | |||
| DEF_OPR_PARAM(ShuffleRNG); | |||
| public: | |||
| virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices, | |||
| _megdnn_tensor_out grad, _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||
| const TensorLayout& indices, | |||
| const TensorLayout& grad) = 0; | |||
| virtual void exec( | |||
| _megdnn_tensor_in diff, _megdnn_tensor_in indices, _megdnn_tensor_out grad, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& diff, const TensorLayout& indices, | |||
| const TensorLayout& grad) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& diff, const TensorLayout& indices, | |||
| const TensorLayout& grad, size_t workspace_in_bytes); | |||
| void check_exec( | |||
| const TensorLayout& diff, const TensorLayout& indices, | |||
| const TensorLayout& grad, size_t workspace_in_bytes); | |||
| }; | |||
| /*! | |||
| * \brief sleep for specific time on the computing device; useful for testing | |||
| * async problems | |||
| */ | |||
| class SleepForward: public OperatorBase { | |||
| class SleepForward : public OperatorBase { | |||
| DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0); | |||
| DEF_OPR_PARAM(Sleep); | |||
| public: | |||
| virtual void exec() = 0; | |||
| public: | |||
| virtual void exec() = 0; | |||
| }; | |||
| using Sleep = SleepForward; | |||
| @@ -149,20 +165,19 @@ using Sleep = SleepForward; | |||
| * | |||
| * data must be a one-dimensional contiguous tensor with dtype byte | |||
| */ | |||
| class ChecksumForward: public OperatorBase { | |||
| class ChecksumForward : public OperatorBase { | |||
| DEF_OPR_PARAM(Empty); | |||
| DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1); | |||
| public: | |||
| using Result = opr_result::Checksum; | |||
| public: | |||
| using Result = opr_result::Checksum; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout &data) = 0; | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& data) = 0; | |||
| virtual Result exec(_megdnn_tensor_in data, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout &layout, size_t workspace_in_bytes); | |||
| protected: | |||
| void check_exec(const TensorLayout& layout, size_t workspace_in_bytes); | |||
| }; | |||
| using Checksum = ChecksumForward; | |||
| @@ -175,21 +190,22 @@ class MaxTensorDiff : public OperatorBase { | |||
| DEF_OPR_PARAM(Empty); | |||
| DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2); | |||
| public: | |||
| virtual size_t get_workspace_in_bytes(const TensorLayout& layout1, | |||
| const TensorLayout& layout2) = 0; | |||
| public: | |||
| virtual size_t get_workspace_in_bytes( | |||
| const TensorLayout& layout1, const TensorLayout& layout2) = 0; | |||
| virtual float exec(_megdnn_tensor_in src1, _megdnn_tensor_in src2, | |||
| _megdnn_workspace workspace) = 0; | |||
| virtual float exec( | |||
| _megdnn_tensor_in src1, _megdnn_tensor_in src2, | |||
| _megdnn_workspace workspace) = 0; | |||
| protected: | |||
| void check_exec(const TensorLayout& layout1, | |||
| const TensorLayout& layout2, size_t workspace_in_bytes); | |||
| protected: | |||
| void check_exec( | |||
| const TensorLayout& layout1, const TensorLayout& layout2, | |||
| size_t workspace_in_bytes); | |||
| }; | |||
| bool check_bias_share_in_channel(const TensorLayout& bias, | |||
| const param::ConvBias::Format format); | |||
| bool check_bias_share_in_channel( | |||
| const TensorLayout& bias, const param::ConvBias::Format format); | |||
| } // namespace megdnn | |||
| @@ -18,9 +18,9 @@ | |||
| namespace megdnn { | |||
| enum class TensorFormat::Type { | |||
| DEFAULT = 0, //!< see DefaultTensorFormat | |||
| IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat | |||
| LOWBITS_ALIGNED_TO_BYTE = 2, //!< | |||
| DEFAULT = 0, //!< see DefaultTensorFormat | |||
| IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat | |||
| LOWBITS_ALIGNED_TO_BYTE = 2, //!< | |||
| }; | |||
| class TensorFormat::ImplBase { | |||
| @@ -33,8 +33,7 @@ public: | |||
| virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0; | |||
| virtual TensorLayout collapse_contiguous_spec( | |||
| const TensorLayout& layout) const = 0; | |||
| virtual TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const = 0; | |||
| virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0; | |||
| @@ -79,8 +78,7 @@ public: | |||
| */ | |||
| bool is_contiguous_spec(const TensorLayout& layout) const override; | |||
| TensorLayout collapse_contiguous_spec( | |||
| const TensorLayout& layout) const override; | |||
| TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override; | |||
| TensorLayout::Span span_spec(const TensorLayout& layout) const override; | |||
| @@ -88,8 +86,7 @@ public: | |||
| void serialize_append(std::string& result) const override; | |||
| static TensorFormat make(); | |||
| static TensorFormat deserialize(const Handle* handle, const void* buf, | |||
| size_t size); | |||
| static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size); | |||
| }; | |||
| namespace detail { | |||
| @@ -112,8 +109,8 @@ class Image2DTensorFormatBase : public TensorFormat::ImplBase { | |||
| size_t m_align_axis, m_align_size_in_elements_log2; | |||
| protected: | |||
| Image2DTensorFormatBase(Type type, size_t align_axis, | |||
| size_t align_size_in_elements); | |||
| Image2DTensorFormatBase( | |||
| Type type, size_t align_axis, size_t align_size_in_elements); | |||
| virtual ~Image2DTensorFormatBase() = default; | |||
| public: | |||
| @@ -129,9 +126,7 @@ public: | |||
| size_t align_axis() const { return m_align_axis; } | |||
| size_t align_size_in_elements_log2() const { | |||
| return m_align_size_in_elements_log2; | |||
| } | |||
| size_t align_size_in_elements_log2() const { return m_align_size_in_elements_log2; } | |||
| std::string to_string() const override; | |||
| @@ -145,6 +140,7 @@ public: | |||
| size_t image_height(const TensorLayout& layout) const; | |||
| void serialize_append(std::string& result) const override; | |||
| protected: | |||
| struct SerializePack { | |||
| uint8_t align_axis; | |||
| @@ -160,15 +156,14 @@ class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase { | |||
| * align COUNT, but mdl needs align size in byte, which equal to | |||
| * (image_width algin count) * sizeof(data_type) * pixel_size | |||
| */ | |||
| size_t image_pitch_alignment_in_bytes(size_t align_size_in_elements, | |||
| const TensorLayout& layout) const; | |||
| size_t image_pitch_alignment_in_bytes( | |||
| size_t align_size_in_elements, const TensorLayout& layout) const; | |||
| protected: | |||
| Image2DPackedTensorFormatBase(Type type, size_t align_axis, | |||
| size_t align_size_in_elements, | |||
| Handle::HandleVendorType vendor_type) | |||
| : detail::Image2DTensorFormatBase(type, align_axis, | |||
| align_size_in_elements), | |||
| Image2DPackedTensorFormatBase( | |||
| Type type, size_t align_axis, size_t align_size_in_elements, | |||
| Handle::HandleVendorType vendor_type) | |||
| : detail::Image2DTensorFormatBase(type, align_axis, align_size_in_elements), | |||
| m_vendor_type(vendor_type) {} | |||
| virtual ~Image2DPackedTensorFormatBase() = default; | |||
| @@ -197,13 +192,12 @@ public: | |||
| bool is_contiguous_spec(const TensorLayout& layout) const override; | |||
| TensorLayout collapse_contiguous_spec( | |||
| const TensorLayout& layout) const override; | |||
| TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override; | |||
| }; | |||
| using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; | |||
| /*! | |||
| * \brief used for tensors storing lowbit data | |||
| * \brief used for tensors storing lowbit data | |||
| * | |||
| * \param m_size_nbits size in bits of elements in the tensor | |||
| * \param m_align_size_in_bits aligned size in bits | |||
| @@ -213,14 +207,14 @@ class LowbitsAlignedTensorFormatBase : public TensorFormat::ImplBase { | |||
| size_t m_size_nbits, m_align_size_in_bits, m_align_size_in_elements; | |||
| protected: //? | |||
| LowbitsAlignedTensorFormatBase(Type type, size_t size_nbits, | |||
| size_t align_size_in_bits); | |||
| LowbitsAlignedTensorFormatBase( | |||
| Type type, size_t size_nbits, size_t align_size_in_bits); | |||
| virtual ~LowbitsAlignedTensorFormatBase() = default; | |||
| public: | |||
| size_t align_size_in_bits() const { return m_align_size_in_bits; } | |||
| size_t size_nbits() const { return m_size_nbits; } | |||
| std::string to_string() const override; | |||
| @@ -238,8 +232,8 @@ public: | |||
| bool is_contiguous_spec(const TensorLayout& layout) const override; | |||
| TensorLayout collapse_contiguous_spec( | |||
| const TensorLayout& layout) const override; | |||
| TensorLayout collapse_contiguous_spec(const TensorLayout& layout) const override; | |||
| protected: | |||
| struct SerializePack { | |||
| uint8_t size_nbits; | |||
| @@ -254,16 +248,14 @@ protected: | |||
| * | |||
| * This is used for OpenCL. | |||
| */ | |||
| class Image2DPack4TensorFormat final | |||
| : public detail::Image2DPack4TensorFormatBase { | |||
| class Image2DPack4TensorFormat final : public detail::Image2DPack4TensorFormatBase { | |||
| public: | |||
| static constexpr Type TYPE = Type::IMAGE2D_PACK4; | |||
| //! for internal usage or test purposes | |||
| static TensorFormat make_raw(size_t align_axis, | |||
| size_t align_size_in_elements, | |||
| Handle::HandleVendorType vendor_type = | |||
| Handle::HandleVendorType::NOT_SPEC); | |||
| static TensorFormat make_raw( | |||
| size_t align_axis, size_t align_size_in_elements, | |||
| Handle::HandleVendorType vendor_type = Handle::HandleVendorType::NOT_SPEC); | |||
| static TensorFormat make(size_t align_axis, const Handle* handle); | |||
| @@ -273,13 +265,11 @@ public: | |||
| * Note that the alignment may be different if deserialized on another | |||
| * handle | |||
| */ | |||
| static TensorFormat deserialize(const Handle* handle, const void* buf, | |||
| size_t size); | |||
| static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size); | |||
| static bool is_valid_image(const TensorLayout& layout) { | |||
| if (layout.format.type() == TYPE) { | |||
| layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid( | |||
| layout); | |||
| layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid(layout); | |||
| return true; | |||
| } | |||
| return false; | |||
| @@ -288,8 +278,9 @@ public: | |||
| TensorFormat change_axis(size_t axis) const override; | |||
| private: | |||
| Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_elements, | |||
| Handle::HandleVendorType vendor_type) | |||
| Image2DPack4TensorFormat( | |||
| size_t align_axis, size_t align_size_in_elements, | |||
| Handle::HandleVendorType vendor_type) | |||
| : detail::Image2DPack4TensorFormatBase( | |||
| TYPE, align_axis, align_size_in_elements, vendor_type) {} | |||
| }; | |||
| @@ -306,13 +297,12 @@ public: | |||
| static TensorFormat make(size_t size_nbits); | |||
| static TensorFormat deserialize(const Handle* handle, const void* buf, | |||
| size_t size); | |||
| static TensorFormat deserialize(const Handle* handle, const void* buf, size_t size); | |||
| static bool is_valid_layout(const TensorLayout& layout) { | |||
| if (layout.format.type() == TYPE) { | |||
| layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>() | |||
| .assert_valid(layout); | |||
| layout.format.as_impl<LowbitsAlignedToBytesTensorFormat>().assert_valid( | |||
| layout); | |||
| return true; | |||
| } | |||
| return false; | |||
| @@ -320,8 +310,7 @@ public: | |||
| private: | |||
| LowbitsAlignedToBytesTensorFormat(size_t size_nbits) | |||
| : detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, | |||
| BYTE_IN_BITS) {} | |||
| : detail::LowbitsAlignedTensorFormatBase(TYPE, size_nbits, BYTE_IN_BITS) {} | |||
| }; | |||
| } // namespace megdnn | |||
| @@ -167,13 +167,11 @@ public: | |||
| TensorIter(const TensorND& tensor) : m_tensor(tensor) {} | |||
| Iter begin() const { | |||
| return Iter::make(const_cast<TensorND&>(m_tensor), 0); | |||
| } | |||
| Iter begin() const { return Iter::make(const_cast<TensorND&>(m_tensor), 0); } | |||
| Iter end() const { | |||
| return Iter::make(const_cast<TensorND&>(m_tensor), | |||
| m_tensor.layout.total_nr_elems()); | |||
| return Iter::make( | |||
| const_cast<TensorND&>(m_tensor), m_tensor.layout.total_nr_elems()); | |||
| } | |||
| }; | |||
| /*! | |||
| @@ -11,19 +11,19 @@ | |||
| #pragma once | |||
| #include <type_traits> | |||
| #include <cstdlib> | |||
| #include <functional> | |||
| #include <utility> | |||
| #include <memory> | |||
| #include <cstdlib> | |||
| #include <type_traits> | |||
| #include <utility> | |||
| #include "megdnn/internal/visibility_prologue.h" | |||
| namespace megdnn { | |||
| template<typename Signature> | |||
| template <typename Signature> | |||
| using thin_function = ::std::function<Signature>; | |||
| } // namespace megdnn | |||
| } // namespace megdnn | |||
| #include "megdnn/internal/visibility_epilogue.h" | |||
| @@ -58,18 +58,16 @@ protected: | |||
| m_end_ptr(first_elm), | |||
| m_capacity_ptr(static_cast<char*>(first_elm) + size) {} | |||
| void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, | |||
| size_t type_size); | |||
| void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, size_t type_size); | |||
| public: | |||
| size_t size_in_bytes() const { | |||
| return size_t(static_cast<char*>(m_end_ptr) - | |||
| static_cast<char*>(m_begin_ptr)); | |||
| return size_t(static_cast<char*>(m_end_ptr) - static_cast<char*>(m_begin_ptr)); | |||
| } | |||
| size_t capacity_in_bytes() const { | |||
| return size_t(static_cast<char*>(m_capacity_ptr) - | |||
| static_cast<char*>(m_begin_ptr)); | |||
| return size_t( | |||
| static_cast<char*>(m_capacity_ptr) - static_cast<char*>(m_begin_ptr)); | |||
| } | |||
| bool empty() const { return m_begin_ptr == m_end_ptr; } | |||
| @@ -85,20 +83,15 @@ private: | |||
| U m_first_elm; | |||
| protected: | |||
| SmallVectorTemplateCommon(size_t size) | |||
| : SmallVectorBase(&m_first_elm, size) {} | |||
| SmallVectorTemplateCommon(size_t size) : SmallVectorBase(&m_first_elm, size) {} | |||
| void grow_pod(size_t min_sz_in_bytes, size_t type_size) { | |||
| SmallVectorBase::grow_pod(&m_first_elm, min_sz_in_bytes, type_size); | |||
| } | |||
| bool is_small() { | |||
| return m_begin_ptr == static_cast<const void*>(&m_first_elm); | |||
| } | |||
| bool is_small() { return m_begin_ptr == static_cast<const void*>(&m_first_elm); } | |||
| void reset_to_small() { | |||
| m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm; | |||
| } | |||
| void reset_to_small() { m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm; } | |||
| void set_end(T* p) { m_end_ptr = p; } | |||
| @@ -128,20 +121,12 @@ protected: | |||
| public: | |||
| // forwarding iterator creation | |||
| iterator begin() { return static_cast<iterator>(m_begin_ptr); } | |||
| const_iterator begin() const { | |||
| return static_cast<const_iterator>(m_begin_ptr); | |||
| } | |||
| const_iterator cbegin() const { | |||
| return static_cast<const_iterator>(m_begin_ptr); | |||
| } | |||
| const_iterator begin() const { return static_cast<const_iterator>(m_begin_ptr); } | |||
| const_iterator cbegin() const { return static_cast<const_iterator>(m_begin_ptr); } | |||
| iterator end() { return static_cast<iterator>(m_end_ptr); } | |||
| const_iterator end() const { | |||
| return static_cast<const_iterator>(m_end_ptr); | |||
| } | |||
| const_iterator cend() const { | |||
| return static_cast<const_iterator>(m_end_ptr); | |||
| } | |||
| const_iterator end() const { return static_cast<const_iterator>(m_end_ptr); } | |||
| const_iterator cend() const { return static_cast<const_iterator>(m_end_ptr); } | |||
| reference at(size_type idx) { | |||
| if (idx >= size()) { | |||
| @@ -167,13 +152,9 @@ public: | |||
| // reverse iterator creation method. | |||
| reverse_iterator rbegin() { return reverse_iterator(end()); } | |||
| const_reverse_iterator rbegin() const { | |||
| return const_reverse_iterator(end()); | |||
| } | |||
| const_reverse_iterator rbegin() const { return const_reverse_iterator(end()); } | |||
| reverse_iterator rend() { return reverse_iterator(begin()); } | |||
| const_reverse_iterator rend() const { | |||
| return const_reverse_iterator(begin()); | |||
| } | |||
| const_reverse_iterator rend() const { return const_reverse_iterator(begin()); } | |||
| pointer data() { return pointer(begin()); } | |||
| const_pointer data() const { return const_pointer(begin()); } | |||
| @@ -207,8 +188,8 @@ protected: | |||
| template <typename It1, typename It2> | |||
| static void uninitialized_move(It1 first, It1 last, It2 dest) { | |||
| std::uninitialized_copy(std::make_move_iterator(first), | |||
| std::make_move_iterator(last), dest); | |||
| std::uninitialized_copy( | |||
| std::make_move_iterator(first), std::make_move_iterator(last), dest); | |||
| } | |||
| template <typename It1, typename It2> | |||
| @@ -293,9 +274,7 @@ protected: | |||
| memcpy(dest, first, (last - first) * sizeof(T)); | |||
| } | |||
| void grow(size_t min_sz = 0) { | |||
| this->grow_pod(min_sz * sizeof(T), sizeof(T)); | |||
| } | |||
| void grow(size_t min_sz = 0) { this->grow_pod(min_sz * sizeof(T), sizeof(T)); } | |||
| public: | |||
| void push_back(const T& _elm) { | |||
| @@ -318,8 +297,7 @@ public: | |||
| * SmallVector<T, N> can be converted to SmallVectorImpl<T> to erase N | |||
| */ | |||
| template <typename T> | |||
| class SmallVectorImpl | |||
| : public SmallVectorTemplateBase<T, std::is_pod<T>::value> { | |||
| class SmallVectorImpl : public SmallVectorTemplateBase<T, std::is_pod<T>::value> { | |||
| using SuperClass = SmallVectorTemplateBase<T, std::is_pod<T>::value>; | |||
| public: | |||
| @@ -329,8 +307,7 @@ public: | |||
| protected: | |||
| explicit SmallVectorImpl(unsigned n) | |||
| : SmallVectorTemplateBase<T, std::is_pod<T>::value>(n * sizeof(T)) { | |||
| } | |||
| : SmallVectorTemplateBase<T, std::is_pod<T>::value>(n * sizeof(T)) {} | |||
| public: | |||
| SmallVectorImpl(const SmallVectorImpl&) = delete; | |||
| @@ -354,8 +331,7 @@ public: | |||
| } else if (n > this->size()) { | |||
| if (this->capacity() < n) | |||
| this->grow(n); | |||
| for (auto it = this->end(), end = this->begin() + n; it != end; | |||
| ++it) | |||
| for (auto it = this->end(), end = this->begin() + n; it != end; ++it) | |||
| new (&*it) T(); | |||
| this->set_end(this->begin() + n); | |||
| } | |||
| @@ -389,10 +365,11 @@ public: | |||
| void swap(SmallVectorImpl<T>& rhs); | |||
| /// Add the specified range to the end of the SmallVector. | |||
| template <typename in_iter, | |||
| typename = typename std::enable_if<std::is_convertible< | |||
| typename std::iterator_traits<in_iter>::iterator_category, | |||
| std::input_iterator_tag>::value>::type> | |||
| template < | |||
| typename in_iter, | |||
| typename = typename std::enable_if<std::is_convertible< | |||
| typename std::iterator_traits<in_iter>::iterator_category, | |||
| std::input_iterator_tag>::value>::type> | |||
| void append(in_iter in_start, in_iter in_end) { | |||
| size_type num_inputs = std::distance(in_start, in_end); | |||
| // Grow allocated space if needed. | |||
| @@ -432,10 +409,11 @@ public: | |||
| std::uninitialized_fill(this->begin(), this->end(), elm); | |||
| } | |||
| template <typename in_iter, | |||
| typename = typename std::enable_if<std::is_convertible< | |||
| typename std::iterator_traits<in_iter>::iterator_category, | |||
| std::input_iterator_tag>::value>::type> | |||
| template < | |||
| typename in_iter, | |||
| typename = typename std::enable_if<std::is_convertible< | |||
| typename std::iterator_traits<in_iter>::iterator_category, | |||
| std::input_iterator_tag>::value>::type> | |||
| void assign(in_iter in_start, in_iter in_end) { | |||
| clear(); | |||
| append(in_start, in_end); | |||
| @@ -571,8 +549,7 @@ public: | |||
| std::fill_n(it, num_overwritten, elm); | |||
| // Insert the non-overwritten middle part. | |||
| std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten, | |||
| elm); | |||
| std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten, elm); | |||
| return it; | |||
| } | |||
| @@ -646,8 +623,7 @@ public: | |||
| if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) { | |||
| this->grow(); | |||
| } | |||
| new (static_cast<void*>(this->end())) | |||
| T(std::forward<ArgTypes>(args)...); | |||
| new (static_cast<void*>(this->end())) T(std::forward<ArgTypes>(args)...); | |||
| this->set_end(this->end() + 1); | |||
| } | |||
| @@ -661,13 +637,11 @@ public: | |||
| return std::equal(this->begin(), this->end(), rhs.begin()); | |||
| } | |||
| bool operator!=(const SmallVectorImpl<T>& rhs) const { | |||
| return !(*this == rhs); | |||
| } | |||
| bool operator!=(const SmallVectorImpl<T>& rhs) const { return !(*this == rhs); } | |||
| bool operator<(const SmallVectorImpl<T>& rhs) const { | |||
| return std::lexicographical_compare(this->begin(), this->end(), | |||
| rhs.begin(), rhs.end()); | |||
| return std::lexicographical_compare( | |||
| this->begin(), this->end(), rhs.begin(), rhs.end()); | |||
| } | |||
| }; | |||
| @@ -698,15 +672,13 @@ void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) { | |||
| // Copy over the extra elms. | |||
| if (this->size() > rhs.size()) { | |||
| size_t elm_diff = this->size() - rhs.size(); | |||
| this->uninitialized_move(this->begin() + num_shared, this->end(), | |||
| rhs.end()); | |||
| this->uninitialized_move(this->begin() + num_shared, this->end(), rhs.end()); | |||
| rhs.set_end(rhs.end() + elm_diff); | |||
| this->destroy_range(this->begin() + num_shared, this->end()); | |||
| this->set_end(this->begin() + num_shared); | |||
| } else if (rhs.size() > this->size()) { | |||
| size_t elm_diff = rhs.size() - this->size(); | |||
| this->uninitialized_move(rhs.begin() + num_shared, rhs.end(), | |||
| this->end()); | |||
| this->uninitialized_move(rhs.begin() + num_shared, rhs.end(), this->end()); | |||
| this->set_end(this->end() + elm_diff); | |||
| this->destroy_range(rhs.begin() + num_shared, rhs.end()); | |||
| rhs.set_end(rhs.begin() + num_shared); | |||
| @@ -714,8 +686,7 @@ void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) { | |||
| } | |||
| template <typename T> | |||
| SmallVectorImpl<T>& SmallVectorImpl<T>::operator=( | |||
| const SmallVectorImpl<T>& rhs) { | |||
| SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(const SmallVectorImpl<T>& rhs) { | |||
| if (this == &rhs) | |||
| return *this; | |||
| size_t rhs_sz = rhs.size(); | |||
| @@ -740,8 +711,7 @@ SmallVectorImpl<T>& SmallVectorImpl<T>::operator=( | |||
| } else if (cur_sz) { | |||
| std::copy(rhs.begin(), rhs.begin() + cur_sz, this->begin()); | |||
| } | |||
| std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(), | |||
| this->begin() + cur_sz); | |||
| std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz); | |||
| this->set_end(this->begin() + rhs_sz); | |||
| return *this; | |||
| } | |||
| @@ -785,8 +755,7 @@ SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(SmallVectorImpl<T>&& rhs) { | |||
| std::move(rhs.begin(), rhs.begin() + cur_sz, this->begin()); | |||
| } | |||
| this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(), | |||
| this->begin() + cur_sz); | |||
| this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(), this->begin() + cur_sz); | |||
| this->set_end(this->begin() + rhs_sz); | |||
| @@ -826,8 +795,7 @@ class SmallVector : public SmallVectorImpl<T> { | |||
| public: | |||
| SmallVector() : SmallVectorImpl<T>(N) {} | |||
| explicit SmallVector(size_t size, const T& value = T()) | |||
| : SmallVectorImpl<T>(N) { | |||
| explicit SmallVector(size_t size, const T& value = T()) : SmallVectorImpl<T>(N) { | |||
| this->assign(size, value); | |||
| } | |||
| @@ -901,15 +869,13 @@ namespace std { | |||
| /// Implement std::swap in terms of SmallVector swap. | |||
| template <typename T> | |||
| inline void swap(megdnn::SmallVectorImpl<T>& lhs, | |||
| megdnn::SmallVectorImpl<T>& rhs) { | |||
| inline void swap(megdnn::SmallVectorImpl<T>& lhs, megdnn::SmallVectorImpl<T>& rhs) { | |||
| lhs.swap(rhs); | |||
| } | |||
| /// Implement std::swap in terms of SmallVector swap. | |||
| template <typename T, unsigned N> | |||
| inline void swap(megdnn::SmallVector<T, N>& lhs, | |||
| megdnn::SmallVector<T, N>& rhs) { | |||
| inline void swap(megdnn::SmallVector<T, N>& lhs, megdnn::SmallVector<T, N>& rhs) { | |||
| lhs.swap(rhs); | |||
| } | |||
| } // end namespace std | |||
| @@ -17,13 +17,13 @@ | |||
| #include "megdnn/internal/visibility_prologue.h" | |||
| namespace megdnn { | |||
| struct Version { | |||
| int major, minor, patch; | |||
| }; | |||
| struct Version { | |||
| int major, minor, patch; | |||
| }; | |||
| //! get megdnn version of the binary | |||
| Version get_version(); | |||
| } | |||
| //! get megdnn version of the binary | |||
| Version get_version(); | |||
| } // namespace megdnn | |||
| #include "megdnn/internal/visibility_epilogue.h" | |||
| @@ -22,18 +22,17 @@ using namespace aarch64; | |||
| /* ===================== stride-2 algo ===================== */ | |||
| MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp16) | |||
| bool ConvBiasImpl::AlgoF16DirectStride2::usable(const NCBKernSizeParam& param, | |||
| AlgoSelectionStrategy) const { | |||
| bool ConvBiasImpl::AlgoF16DirectStride2::usable( | |||
| const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||
| MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 0) { | |||
| auto&& fm = param.filter_meta; | |||
| auto FH = fm.spatial[0]; | |||
| return param.filter_meta.format == param::Convolution::Format::NCHW && | |||
| param.src_type.enumv() == DTypeEnum::Float16 && | |||
| param.filter_type.enumv() == DTypeEnum::Float16 && | |||
| param.dst_type.enumv() == DTypeEnum::Float16 && | |||
| !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
| fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||
| FH == fm.spatial[1] && | |||
| param.dst_type.enumv() == DTypeEnum::Float16 && !fm.should_flip && | |||
| fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
| fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||
| (FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
| } | |||
| MIDOUT_END(); | |||
| @@ -52,8 +51,7 @@ size_t ConvBiasImpl::AlgoF16DirectStride2::get_workspace( | |||
| return 0; | |||
| } | |||
| SmallVector<ConvBiasImpl::NCBKern> | |||
| ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( | |||
| SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( | |||
| const NCBKernSizeParam& param) const { | |||
| MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp16, 0, 2) { | |||
| return get_kimpls(param); | |||
| @@ -62,8 +60,7 @@ ConvBiasImpl::AlgoF16DirectStride2::dispatch_kerns( | |||
| return {}; | |||
| } | |||
| SmallVector<ConvBiasImpl::NCBKern> | |||
| ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||
| SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||
| const NCBKernSizeParam& param) const { | |||
| auto fm = param.filter_meta; | |||
| auto FH = fm.spatial[0]; | |||
| @@ -72,8 +69,9 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||
| size_t OC = param.filter_meta.ocpg; | |||
| size_t group = fm.group; | |||
| bool large_group = group >= param.nr_threads; | |||
| using Func = std::function<void(const __fp16*, const __fp16*, __fp16*, | |||
| size_t, size_t, size_t, size_t, size_t)>; | |||
| using Func = std::function<void( | |||
| const __fp16*, const __fp16*, __fp16*, size_t, size_t, size_t, size_t, | |||
| size_t)>; | |||
| Func conv = nullptr; | |||
| if (FH == 2) { | |||
| conv = fp16::conv_stride2::do_conv_2x2_stride2; | |||
| @@ -101,31 +99,35 @@ ConvBiasImpl::AlgoF16DirectStride2::get_kimpls( | |||
| bundle.set(kern_param.workspace_ptr); | |||
| for (size_t ic = 0; ic < IC; ic++) { | |||
| arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
| copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||
| {ncb_index.thread_id, 0, ic}); | |||
| copy_padding_kern_stride( | |||
| bundle, kern_param, ncb_index, | |||
| {ncb_index.thread_id, 0, ic}); | |||
| } | |||
| for (size_t oc = 0; oc < OC; oc++) { | |||
| arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
| do_conv_kern_stride(bundle, kern_param, ncb_index, conv, | |||
| {ncb_index.thread_id, 0, oc}); | |||
| do_conv_kern_stride( | |||
| bundle, kern_param, ncb_index, conv, | |||
| {ncb_index.thread_id, 0, oc}); | |||
| } | |||
| }; | |||
| ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | |||
| } else { | |||
| auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||
| const NCBKernIndex& ncb_index) mutable { | |||
| auto copy_padding = [bundle]( | |||
| const NCBKernParam& kern_param, | |||
| const NCBKernIndex& ncb_index) mutable { | |||
| bundle.set(kern_param.workspace_ptr); | |||
| arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
| copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||
| ncb_index.ndrange_id); | |||
| copy_padding_kern_stride( | |||
| bundle, kern_param, ncb_index, ncb_index.ndrange_id); | |||
| }; | |||
| ret_kerns.push_back({copy_padding, {group, N, IC}}); | |||
| auto do_conv = [bundle, conv](const NCBKernParam& kern_param, | |||
| const NCBKernIndex& ncb_index) mutable { | |||
| auto do_conv = [bundle, conv]( | |||
| const NCBKernParam& kern_param, | |||
| const NCBKernIndex& ncb_index) mutable { | |||
| bundle.set(kern_param.workspace_ptr); | |||
| arm_common::MultithreadDirectConvCommon<dt_float16, __fp16>:: | |||
| do_conv_kern_stride(bundle, kern_param, ncb_index, conv, | |||
| ncb_index.ndrange_id); | |||
| do_conv_kern_stride( | |||
| bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id); | |||
| }; | |||
| ret_kerns.push_back({do_conv, {group, N, OC}}); | |||
| } | |||
| @@ -18,13 +18,13 @@ namespace aarch64 { | |||
| /* ===================== stride-2 algo ===================== */ | |||
| class ConvBiasImpl::AlgoF16DirectStride2 final : public AlgoBase { | |||
| SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "ARMV8F16STRD2"; } | |||
| bool usable(const NCBKernSizeParam& param, | |||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||
| bool usable( | |||
| const NCBKernSizeParam& param, | |||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||
| size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
| @@ -20,9 +20,9 @@ namespace aarch64 { | |||
| namespace fp16 { | |||
| namespace conv_stride2 { | |||
| static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter, | |||
| __fp16* dst, size_t IH, size_t IW, size_t OH, | |||
| size_t OW, size_t IC) { | |||
| static void do_conv_2x2_stride2( | |||
| const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, | |||
| size_t OH, size_t OW, size_t IC) { | |||
| const size_t tail_step = IW - 2 * OW + IW; | |||
| size_t width = OW >> 3; | |||
| size_t mod4_left = width & 3; | |||
| @@ -162,10 +162,9 @@ static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter, | |||
| "5: \n" | |||
| : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) | |||
| : "r"(mod4_left), "w"(_k0123) | |||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
| "v15", "v16", "v17", "v18", "v19", "v28", "v29", "v30", | |||
| "v31"); | |||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||
| "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
| "v17", "v18", "v19", "v28", "v29", "v30", "v31"); | |||
| r0 += tail_step; | |||
| r1 += tail_step; | |||
| @@ -175,9 +174,9 @@ static void do_conv_2x2_stride2(const __fp16* src, const __fp16* filter, | |||
| } | |||
| } | |||
| static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter, | |||
| __fp16* dst, size_t IH, size_t IW, size_t OH, | |||
| size_t OW, size_t IC) { | |||
| static void do_conv_3x3_stride2( | |||
| const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, | |||
| size_t OH, size_t OW, size_t IC) { | |||
| const size_t tail_step = IW - 2 * OW + IW; | |||
| size_t width = OW >> 3; | |||
| size_t mod3_left = width % 3; | |||
| @@ -352,10 +351,10 @@ static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter, | |||
| "3: \n" | |||
| : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) | |||
| : "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) | |||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
| "v15", "v16", "v17", "v18", "v21", "v22", "v23", "v24", | |||
| "v25", "v26", "v27", "v28", "v29"); | |||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||
| "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
| "v17", "v18", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
| "v28", "v29"); | |||
| r0 += tail_step; | |||
| r1 += tail_step; | |||
| @@ -366,9 +365,9 @@ static void do_conv_3x3_stride2(const __fp16* src, const __fp16* filter, | |||
| } | |||
| } | |||
| static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, | |||
| __fp16* dst, size_t IH, size_t IW, size_t OH, | |||
| size_t OW, size_t IC) { | |||
| static void do_conv_5x5_stride2( | |||
| const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, | |||
| size_t OH, size_t OW, size_t IC) { | |||
| const size_t tail_step = IW - 2 * OW + IW; | |||
| size_t width = OW >> 3; | |||
| size_t mod2_left = width & 1; | |||
| @@ -384,18 +383,12 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, | |||
| const __fp16* r4 = src_ptr + IW * 4; | |||
| register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); | |||
| register MEGDNN_SIMD_TYPE _k4567 asm("v1") = | |||
| MEGDNN_SIMD_LOADU(filter + 4); | |||
| register MEGDNN_SIMD_TYPE _k891011 asm("v2") = | |||
| MEGDNN_SIMD_LOADU(filter + 8); | |||
| register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = | |||
| MEGDNN_SIMD_LOADU(filter + 12); | |||
| register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = | |||
| MEGDNN_SIMD_LOADU(filter + 16); | |||
| register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = | |||
| MEGDNN_SIMD_LOADU(filter + 20); | |||
| register MEGDNN_SIMD_TYPE _k24242424 asm("v6") = | |||
| MEGDNN_SIMD_SET1(filter[24]); | |||
| register MEGDNN_SIMD_TYPE _k4567 asm("v1") = MEGDNN_SIMD_LOADU(filter + 4); | |||
| register MEGDNN_SIMD_TYPE _k891011 asm("v2") = MEGDNN_SIMD_LOADU(filter + 8); | |||
| register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = MEGDNN_SIMD_LOADU(filter + 12); | |||
| register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = MEGDNN_SIMD_LOADU(filter + 16); | |||
| register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = MEGDNN_SIMD_LOADU(filter + 20); | |||
| register MEGDNN_SIMD_TYPE _k24242424 asm("v6") = MEGDNN_SIMD_SET1(filter[24]); | |||
| for (size_t i = 0; i < OH; i++) { | |||
| asm volatile( | |||
| @@ -592,15 +585,14 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, | |||
| "bne 2b \n" | |||
| "3: \n" | |||
| : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), | |||
| "+r"(r3), "+r"(r4) | |||
| : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), | |||
| "+r"(r4) | |||
| : "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), | |||
| "w"(_k16171819), "w"(_k20212223), "w"(_k24242424), | |||
| "r"(mod2_left) | |||
| : "cc", "memory", "x1", "v7", "v8", "v9", "v10", "v11", | |||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
| "v28", "v29", "v30", "v31"); | |||
| "w"(_k16171819), "w"(_k20212223), "w"(_k24242424), "r"(mod2_left) | |||
| : "cc", "memory", "x1", "v7", "v8", "v9", "v10", "v11", "v12", | |||
| "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
| "v31"); | |||
| r0 += tail_step; | |||
| r1 += tail_step; | |||
| @@ -613,9 +605,9 @@ static void do_conv_5x5_stride2(const __fp16* src, const __fp16* filter, | |||
| } | |||
| } | |||
| static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter, | |||
| __fp16* dst, size_t IH, size_t IW, size_t OH, | |||
| size_t OW, size_t IC) { | |||
| static void do_conv_7x7_stride2( | |||
| const __fp16* src, const __fp16* filter, __fp16* dst, size_t IH, size_t IW, | |||
| size_t OH, size_t OW, size_t IC) { | |||
| const size_t tail_step = IW - 2 * OW + IW; | |||
| size_t width = OW >> 3; | |||
| @@ -632,30 +624,20 @@ static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter, | |||
| const __fp16* r6 = src_ptr + IW * 6; | |||
| register MEGDNN_SIMD_TYPE _k0123 asm("v0") = MEGDNN_SIMD_LOADU(filter); | |||
| register MEGDNN_SIMD_TYPE _k4567 asm("v1") = | |||
| MEGDNN_SIMD_LOADU(filter + 4); | |||
| register MEGDNN_SIMD_TYPE _k891011 asm("v2") = | |||
| MEGDNN_SIMD_LOADU(filter + 8); | |||
| register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = | |||
| MEGDNN_SIMD_LOADU(filter + 12); | |||
| register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = | |||
| MEGDNN_SIMD_LOADU(filter + 16); | |||
| register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = | |||
| MEGDNN_SIMD_LOADU(filter + 20); | |||
| register MEGDNN_SIMD_TYPE _k24252627 asm("v6") = | |||
| MEGDNN_SIMD_LOADU(filter + 24); | |||
| register MEGDNN_SIMD_TYPE _k28293031 asm("v7") = | |||
| MEGDNN_SIMD_LOADU(filter + 28); | |||
| register MEGDNN_SIMD_TYPE _k32333435 asm("v8") = | |||
| MEGDNN_SIMD_LOADU(filter + 32); | |||
| register MEGDNN_SIMD_TYPE _k36373839 asm("v9") = | |||
| MEGDNN_SIMD_LOADU(filter + 36); | |||
| register MEGDNN_SIMD_TYPE _k4567 asm("v1") = MEGDNN_SIMD_LOADU(filter + 4); | |||
| register MEGDNN_SIMD_TYPE _k891011 asm("v2") = MEGDNN_SIMD_LOADU(filter + 8); | |||
| register MEGDNN_SIMD_TYPE _k12131415 asm("v3") = MEGDNN_SIMD_LOADU(filter + 12); | |||
| register MEGDNN_SIMD_TYPE _k16171819 asm("v4") = MEGDNN_SIMD_LOADU(filter + 16); | |||
| register MEGDNN_SIMD_TYPE _k20212223 asm("v5") = MEGDNN_SIMD_LOADU(filter + 20); | |||
| register MEGDNN_SIMD_TYPE _k24252627 asm("v6") = MEGDNN_SIMD_LOADU(filter + 24); | |||
| register MEGDNN_SIMD_TYPE _k28293031 asm("v7") = MEGDNN_SIMD_LOADU(filter + 28); | |||
| register MEGDNN_SIMD_TYPE _k32333435 asm("v8") = MEGDNN_SIMD_LOADU(filter + 32); | |||
| register MEGDNN_SIMD_TYPE _k36373839 asm("v9") = MEGDNN_SIMD_LOADU(filter + 36); | |||
| register MEGDNN_SIMD_TYPE _k40414243 asm("v10") = | |||
| MEGDNN_SIMD_LOADU(filter + 40); | |||
| register MEGDNN_SIMD_TYPE _k44454647 asm("v11") = | |||
| MEGDNN_SIMD_LOADU(filter + 44); | |||
| register MEGDNN_SIMD_TYPE _k48484848 asm("v12") = | |||
| MEGDNN_SIMD_SET1(filter[48]); | |||
| register MEGDNN_SIMD_TYPE _k48484848 asm("v12") = MEGDNN_SIMD_SET1(filter[48]); | |||
| for (size_t i = 0; i < OH; i++) { | |||
| asm volatile( | |||
| @@ -1005,16 +987,15 @@ static void do_conv_7x7_stride2(const __fp16* src, const __fp16* filter, | |||
| "bne 2b \n" | |||
| "3: \n" | |||
| : "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), | |||
| "+r"(r4), "+r"(r5), "+r"(r6) | |||
| : "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), "+r"(r4), | |||
| "+r"(r5), "+r"(r6) | |||
| : "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), | |||
| "w"(_k12131415), "w"(_k16171819), "w"(_k20212223), | |||
| "w"(_k24252627), "w"(_k28293031), "w"(_k32333435), | |||
| "w"(_k36373839), "w"(_k40414243), "w"(_k44454647), | |||
| "w"(_k48484848) | |||
| : "cc", "memory", "x1", "v13", "v14", "v15", "v16", "v17", | |||
| "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |||
| "v26", "v27", "v28", "v29", "v30", "v31"); | |||
| "w"(_k36373839), "w"(_k40414243), "w"(_k44454647), "w"(_k48484848) | |||
| : "cc", "memory", "x1", "v13", "v14", "v15", "v16", "v17", "v18", | |||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
| "v28", "v29", "v30", "v31"); | |||
| r0 += tail_step; | |||
| r1 += tail_step; | |||
| @@ -21,18 +21,17 @@ using namespace megdnn; | |||
| using namespace aarch64; | |||
| MIDOUT_DECL(megdnn_aarch64_conv_bias_stride2_conv2357_fp32) | |||
| bool ConvBiasImpl::AlgoF32DirectStride2::usable(const NCBKernSizeParam& param, | |||
| AlgoSelectionStrategy) const { | |||
| bool ConvBiasImpl::AlgoF32DirectStride2::usable( | |||
| const NCBKernSizeParam& param, AlgoSelectionStrategy) const { | |||
| MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 0) { | |||
| auto&& fm = param.filter_meta; | |||
| auto FH = fm.spatial[0]; | |||
| return param.filter_meta.format == param::ConvBias::Format::NCHW && | |||
| param.src_type.enumv() == DTypeEnum::Float32 && | |||
| param.filter_type.enumv() == DTypeEnum::Float32 && | |||
| param.dst_type.enumv() == DTypeEnum::Float32 && | |||
| !fm.should_flip && fm.spatial_ndim == 2 && fm.dilation[0] == 1 && | |||
| fm.dilation[1] == 1 && fm.stride[0] == 2 && fm.stride[1] == 2 && | |||
| FH == fm.spatial[1] && | |||
| param.dst_type.enumv() == DTypeEnum::Float32 && !fm.should_flip && | |||
| fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
| fm.stride[0] == 2 && fm.stride[1] == 2 && FH == fm.spatial[1] && | |||
| (FH == 2 || FH == 3 || FH == 5 || FH == 7); | |||
| } | |||
| MIDOUT_END(); | |||
| @@ -50,8 +49,7 @@ size_t ConvBiasImpl::AlgoF32DirectStride2::get_workspace( | |||
| MIDOUT_END(); | |||
| return 0; | |||
| } | |||
| SmallVector<ConvBiasImpl::NCBKern> | |||
| ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | |||
| SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | |||
| const NCBKernSizeParam& param) const { | |||
| MIDOUT_BEGIN(megdnn_aarch64_conv_bias_stride2_conv2357_fp32, 0, 2) { | |||
| return get_kimpls(param); | |||
| @@ -60,8 +58,7 @@ ConvBiasImpl::AlgoF32DirectStride2::dispatch_kerns( | |||
| return {}; | |||
| } | |||
| SmallVector<ConvBiasImpl::NCBKern> | |||
| ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||
| SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||
| const NCBKernSizeParam& param) const { | |||
| auto fm = param.filter_meta; | |||
| auto FH = fm.spatial[0]; | |||
| @@ -70,8 +67,9 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||
| size_t OC = param.filter_meta.ocpg; | |||
| size_t group = fm.group; | |||
| bool large_group = group >= param.nr_threads; | |||
| using Func = std::function<void(const float*, const float*, float*, size_t, | |||
| size_t, size_t, size_t, size_t)>; | |||
| using Func = std::function<void( | |||
| const float*, const float*, float*, size_t, size_t, size_t, size_t, | |||
| size_t)>; | |||
| Func conv = nullptr; | |||
| if (FH == 2) { | |||
| conv = fp32::conv_stride2::do_conv_2x2_stride2; | |||
| @@ -83,8 +81,9 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||
| conv = fp32::conv_stride2::do_conv_7x7_stride2; | |||
| } | |||
| WorkspaceBundle bundle = arm_common::MultithreadDirectConvCommon< | |||
| float, float>::get_bundle_stride(param, large_group); | |||
| WorkspaceBundle bundle = | |||
| arm_common::MultithreadDirectConvCommon<float, float>::get_bundle_stride( | |||
| param, large_group); | |||
| SmallVector<NCBKern> ret_kerns; | |||
| //! Dense conv and small group | |||
| @@ -99,34 +98,34 @@ ConvBiasImpl::AlgoF32DirectStride2::get_kimpls( | |||
| bundle.set(kern_param.workspace_ptr); | |||
| for (size_t ic = 0; ic < IC; ic++) { | |||
| arm_common::MultithreadDirectConvCommon<float, float>:: | |||
| copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||
| {ncb_index.thread_id, 0, ic}); | |||
| copy_padding_kern_stride( | |||
| bundle, kern_param, ncb_index, | |||
| {ncb_index.thread_id, 0, ic}); | |||
| } | |||
| for (size_t oc = 0; oc < OC; oc++) { | |||
| arm_common::MultithreadDirectConvCommon< | |||
| float, float>::do_conv_kern_stride(bundle, kern_param, | |||
| ncb_index, conv, | |||
| {ncb_index.thread_id, | |||
| 0, oc}); | |||
| arm_common::MultithreadDirectConvCommon<float, float>:: | |||
| do_conv_kern_stride( | |||
| bundle, kern_param, ncb_index, conv, | |||
| {ncb_index.thread_id, 0, oc}); | |||
| } | |||
| }; | |||
| ret_kerns.push_back({exec_one_group, {group, N, 1_z}}); | |||
| } else { | |||
| auto copy_padding = [bundle](const NCBKernParam& kern_param, | |||
| const NCBKernIndex& ncb_index) mutable { | |||
| auto copy_padding = [bundle]( | |||
| const NCBKernParam& kern_param, | |||
| const NCBKernIndex& ncb_index) mutable { | |||
| bundle.set(kern_param.workspace_ptr); | |||
| arm_common::MultithreadDirectConvCommon<float, float>:: | |||
| copy_padding_kern_stride(bundle, kern_param, ncb_index, | |||
| ncb_index.ndrange_id); | |||
| copy_padding_kern_stride( | |||
| bundle, kern_param, ncb_index, ncb_index.ndrange_id); | |||
| }; | |||
| ret_kerns.push_back({copy_padding, {group, N, IC}}); | |||
| auto do_conv = [bundle, conv](const NCBKernParam& kern_param, | |||
| const NCBKernIndex& ncb_index) mutable { | |||
| auto do_conv = [bundle, conv]( | |||
| const NCBKernParam& kern_param, | |||
| const NCBKernIndex& ncb_index) mutable { | |||
| bundle.set(kern_param.workspace_ptr); | |||
| arm_common::MultithreadDirectConvCommon< | |||
| float, float>::do_conv_kern_stride(bundle, kern_param, | |||
| ncb_index, conv, | |||
| ncb_index.ndrange_id); | |||
| arm_common::MultithreadDirectConvCommon<float, float>::do_conv_kern_stride( | |||
| bundle, kern_param, ncb_index, conv, ncb_index.ndrange_id); | |||
| }; | |||
| ret_kerns.push_back({do_conv, {group, N, OC}}); | |||
| } | |||
| @@ -22,14 +22,14 @@ using FallbackConvBiasImpl = fallback::ConvBiasImpl; | |||
| class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | |||
| SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const; | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "ARMV8F32STRD2"; } | |||
| bool usable(const NCBKernSizeParam& param, | |||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||
| bool usable( | |||
| const NCBKernSizeParam& param, | |||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||
| size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
| @@ -16,16 +16,15 @@ | |||
| namespace megdnn { | |||
| namespace aarch64 { | |||
| namespace fp32{ | |||
| namespace fp32 { | |||
| namespace conv_stride2 { | |||
| //! For the detail tune process, refer to `expr/conv_aarch64_stride2/main.cpp` | |||
| // refer to function do_conv_2x2_stride2_asm_unroll4 | |||
| static void do_conv_2x2_stride2(const float* src, const float* filter, | |||
| float* dst, size_t IH, size_t IW, size_t OH, | |||
| size_t OW, size_t IC) { | |||
| static void do_conv_2x2_stride2( | |||
| const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
| size_t OH, size_t OW, size_t IC) { | |||
| const size_t tail_step = IW - 2 * OW + IW; | |||
| size_t width = OW >> 2; | |||
| size_t mod4_left = width & 3; | |||
| @@ -165,10 +164,9 @@ static void do_conv_2x2_stride2(const float* src, const float* filter, | |||
| "5: \n" | |||
| : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1) | |||
| : "r"(mod4_left), "w"(_k0123) | |||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
| "v15", "v16", "v17", "v18", "v19", "v28", "v29", "v30", | |||
| "v31"); | |||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||
| "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
| "v17", "v18", "v19", "v28", "v29", "v30", "v31"); | |||
| r0 += tail_step; | |||
| r1 += tail_step; | |||
| @@ -179,9 +177,9 @@ static void do_conv_2x2_stride2(const float* src, const float* filter, | |||
| } | |||
| // refer to function do_conv_3x3_stride2_asm_unroll3 | |||
| static void do_conv_3x3_stride2(const float* src, const float* filter, | |||
| float* dst, size_t IH, size_t IW, size_t OH, | |||
| size_t OW, size_t IC) { | |||
| static void do_conv_3x3_stride2( | |||
| const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
| size_t OH, size_t OW, size_t IC) { | |||
| const size_t tail_step = IW - 2 * OW + IW; | |||
| size_t width = OW >> 2; | |||
| size_t mod3_left = width % 3; | |||
| @@ -269,7 +267,7 @@ static void do_conv_3x3_stride2(const float* src, const float* filter, | |||
| "ld2 {v1.4s, v2.4s}, [%2], #32 \n" // 0, 2, 4, 6 | |||
| "ld2 {v5.4s, v6.4s}, [%3], #32 \n" | |||
| "ld1 {v3.4s}, [%2] \n" // load src 8 12 ... | |||
| "ld1 {v3.4s}, [%2] \n" // load src 8 12 ... | |||
| "fmla v0.4s, v1.4s, v21.4s \n" // src[i] * k[i] | |||
| "ext v7.16b, v1.16b, v3.16b, #4 \n" // 2, 4, 6, 8 | |||
| "fmla v0.4s, v2.4s, v22.4s \n" | |||
| @@ -356,10 +354,10 @@ static void do_conv_3x3_stride2(const float* src, const float* filter, | |||
| "3: \n" | |||
| : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2) | |||
| : "r"(mod3_left), "w"(_k0123), "w"(_k3456), "w"(_k5678) | |||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
| "v15", "v16", "v17", "v18", "v21", "v22", "v23", "v24", | |||
| "v25", "v26", "v27", "v28", "v29"); | |||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||
| "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
| "v17", "v18", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
| "v28", "v29"); | |||
| r0 += tail_step; | |||
| r1 += tail_step; | |||
| @@ -371,9 +369,9 @@ static void do_conv_3x3_stride2(const float* src, const float* filter, | |||
| } | |||
| // refer to function do_conv_5x5_stride2_asm_unroll2 | |||
| static void do_conv_5x5_stride2(const float* src, const float* filter, | |||
| float* dst, size_t IH, size_t IW, size_t OH, | |||
| size_t OW, size_t IC) { | |||
| static void do_conv_5x5_stride2( | |||
| const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
| size_t OH, size_t OW, size_t IC) { | |||
| const size_t tail_step = IW - 2 * OW + IW; | |||
| size_t width = OW >> 2; | |||
| size_t mod2_left = width & 1; | |||
| @@ -591,15 +589,13 @@ static void do_conv_5x5_stride2(const float* src, const float* filter, | |||
| "bne 2b \n" | |||
| "3: \n" | |||
| : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), | |||
| "+r"(r3), "+r"(r4) | |||
| : "+r"(width), "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), | |||
| "+r"(r4) | |||
| : "w"(_k0123), "w"(_k4567), "w"(_k891011), "w"(_k12131415), | |||
| "w"(_k16171819), "w"(_k20212223), "w"(_k24242424), | |||
| "r"(mod2_left) | |||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||
| "v23", "v24"); | |||
| "w"(_k16171819), "w"(_k20212223), "w"(_k24242424), "r"(mod2_left) | |||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||
| "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
| "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24"); | |||
| r0 += tail_step; | |||
| r1 += tail_step; | |||
| @@ -613,9 +609,9 @@ static void do_conv_5x5_stride2(const float* src, const float* filter, | |||
| } | |||
| // refer to function do_conv_7x7_stride2_asm_unroll2 | |||
| static void do_conv_7x7_stride2(const float* src, const float* filter, | |||
| float* dst, size_t IH, size_t IW, size_t OH, | |||
| size_t OW, size_t IC) { | |||
| static void do_conv_7x7_stride2( | |||
| const float* src, const float* filter, float* dst, size_t IH, size_t IW, | |||
| size_t OH, size_t OW, size_t IC) { | |||
| const size_t tail_step = IW - 2 * OW + IW; | |||
| size_t width = OW >> 2; | |||
| @@ -993,16 +989,15 @@ static void do_conv_7x7_stride2(const float* src, const float* filter, | |||
| "bne 2b \n" | |||
| "3: \n" | |||
| : "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), | |||
| "+r"(r4), "+r"(r5), "+r"(r6) | |||
| : "+r"(outptr), "+r"(r0), "+r"(r1), "+r"(r2), "+r"(r3), "+r"(r4), | |||
| "+r"(r5), "+r"(r6) | |||
| : "r"(width), "w"(_k0123), "w"(_k4567), "w"(_k891011), | |||
| "w"(_k12131415), "w"(_k16171819), "w"(_k20212223), | |||
| "w"(_k24252627), "w"(_k28293031), "w"(_k32333435), | |||
| "w"(_k36373839), "w"(_k40414243), "w"(_k44454647), | |||
| "w"(_k48484848) | |||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", | |||
| "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
| "v15", "v16", "v17", "v18"); | |||
| "w"(_k36373839), "w"(_k40414243), "w"(_k44454647), "w"(_k48484848) | |||
| : "cc", "memory", "x1", "v0", "v1", "v2", "v3", "v4", "v5", "v6", | |||
| "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
| "v17", "v18"); | |||
| r0 += tail_step; | |||
| r1 += tail_step; | |||
| @@ -68,9 +68,9 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( | |||
| size_t N = OH * OW; | |||
| #if MGB_ENABLE_DOT | |||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
| _bias_midout_enum, _nonline, \ | |||
| _nonline_midout_enum) \ | |||
| #define DISPATCH_GEMM_STRATEGY( \ | |||
| _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||
| _nonline_midout_enum) \ | |||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
| part2 = megdnn::matmul::GemmInterleaved< \ | |||
| @@ -84,11 +84,12 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( | |||
| DISPATCH_GEMM_BIAS(s8_4x4, 0) | |||
| } | |||
| #else | |||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
| _bias_midout_enum, _nonline, \ | |||
| _nonline_midout_enum) \ | |||
| MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \ | |||
| _bias_midout_enum, _nonline_midout_enum) { \ | |||
| #define DISPATCH_GEMM_STRATEGY( \ | |||
| _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||
| _nonline_midout_enum) \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_aarch64_conv_bias_int8_gemm, 0, _gemm_midout_enum, \ | |||
| _bias_midout_enum, _nonline_midout_enum) { \ | |||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
| part2 = megdnn::matmul::GemmInterleaved< \ | |||
| @@ -104,8 +105,8 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( | |||
| return {nullptr, {part0, part1, part2}}; | |||
| } | |||
| void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, | |||
| const NCBKernIndex& ncb_index) { | |||
| void ConvBiasImpl::AlgoS8MatrixMul::kimpl( | |||
| const NCBKernParam& param, const NCBKernIndex& ncb_index) { | |||
| auto is_xcorr = !param.filter_meta.should_flip; | |||
| UNPACK_CONV_NCB_KERN_SIZES(param); | |||
| auto bundle = get_bundle(param); | |||
| @@ -157,29 +158,28 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, | |||
| img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | |||
| } else { | |||
| if (is_xcorr) | |||
| img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||
| FW, SH, SW); | |||
| img2col_stride<true>( | |||
| src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); | |||
| else | |||
| img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||
| FW, SH, SW); | |||
| img2col_stride<false>( | |||
| src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); | |||
| } | |||
| } | |||
| { | |||
| Workspace workspace(static_cast<dt_byte*>(bundle.get(2)), | |||
| bundle.get_size(2)); | |||
| Workspace workspace( | |||
| static_cast<dt_byte*>(bundle.get(2)), bundle.get_size(2)); | |||
| size_t M = OC; | |||
| size_t K = IC * FH * FW; | |||
| size_t N = OH * OW; | |||
| #if MGB_ENABLE_DOT | |||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
| _bias_midout_enum, _nonline, \ | |||
| _nonline_midout_enum) \ | |||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
| megdnn::matmul::GemmInterleaved< \ | |||
| matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||
| #define DISPATCH_GEMM_STRATEGY( \ | |||
| _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||
| _nonline_midout_enum) \ | |||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
| megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||
| gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); | |||
| if (cpuinfo_has_arm_neon_dot()) { | |||
| @@ -188,19 +188,18 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, | |||
| DISPATCH_GEMM_BIAS(s8_4x4, 0) | |||
| } | |||
| #else | |||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
| _bias_midout_enum, _nonline, \ | |||
| _nonline_midout_enum) \ | |||
| MIDOUT_BEGIN(megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \ | |||
| _bias_midout_enum, _nonline_midout_enum) { \ | |||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
| megdnn::matmul::GemmInterleaved< \ | |||
| matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||
| gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ | |||
| bias); \ | |||
| } \ | |||
| #define DISPATCH_GEMM_STRATEGY( \ | |||
| _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||
| _nonline_midout_enum) \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_aarch64_conv_bias_int8_gemm, 1, _gemm_midout_enum, \ | |||
| _bias_midout_enum, _nonline_midout_enum) { \ | |||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
| megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||
| gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \ | |||
| } \ | |||
| MIDOUT_END() | |||
| DISPATCH_GEMM_BIAS(s8_4x4, 0) | |||
| #endif | |||
| @@ -12,8 +12,8 @@ | |||
| #pragma once | |||
| #include "src/aarch64/conv_bias/opr_impl.h" | |||
| #include "src/fallback/conv_bias/opr_impl.h" | |||
| #include "src/common/opr_delegate.h" | |||
| #include "src/fallback/conv_bias/opr_impl.h" | |||
| namespace megdnn { | |||
| namespace aarch64 { | |||
| @@ -25,18 +25,16 @@ class ConvBiasImpl::AlgoS8MatrixMul final : public AlgoBase { | |||
| static void kimpl(const NCBKernParam& param, const NCBKernIndex& ncb_index); | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "S8MATMUL"; } | |||
| bool usable(const NCBKernSizeParam& param, | |||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||
| bool usable( | |||
| const NCBKernSizeParam& param, | |||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||
| size_t get_workspace(const NCBKernSizeParam& param) const override { | |||
| return get_bundle(param).total_size_in_bytes(); | |||
| } | |||
| SmallVector<NCBKern> dispatch_kerns( | |||
| const NCBKernSizeParam& param) const override { | |||
| SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam& param) const override { | |||
| size_t group = param.filter_meta.group; | |||
| return {{kimpl, {group, 1_z, 1_z}}}; | |||
| } | |||
| @@ -29,9 +29,10 @@ struct KernCaller; | |||
| #if MGB_ENABLE_DOT | |||
| template <BiasMode bmode, typename Op> | |||
| struct KernCaller<bmode, Op, 8, 12> { | |||
| static void run(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
| size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k, | |||
| Op op, const dt_int32* bias, dt_int32* workspace) { | |||
| static void run( | |||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
| dt_int8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, | |||
| dt_int32* workspace) { | |||
| megdnn_assert(is_first_k); | |||
| constexpr size_t A_INTERLEAVE = 8; | |||
| @@ -49,19 +50,19 @@ struct KernCaller<bmode, Op, 8, 12> { | |||
| size_t n = 0; | |||
| const dt_int8* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_8x12x4::kern_8x12(packA, cur_packB, K, workspace, 12, | |||
| is_first_k); | |||
| matmul_8x12x4::kern_8x12( | |||
| packA, cur_packB, K, workspace, 12, is_first_k); | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 12, 8, | |||
| 12>::postprocess(bias, workspace, | |||
| output, LDC, op); | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 12, 8, 12>:: | |||
| postprocess(bias, workspace, output, LDC, op); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K12; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_8x12x4::kern_8x4(packA, cur_packB, K, workspace, 4, | |||
| is_first_k, std::min<size_t>(N - n, 4)); | |||
| matmul_8x12x4::kern_8x4( | |||
| packA, cur_packB, K, workspace, 4, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| #define cb(m, n) \ | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 8, 4, 8, n>::postprocess( \ | |||
| @@ -83,9 +84,9 @@ struct KernCaller<bmode, Op, 8, 12> { | |||
| const dt_int8* cur_packB = packB; | |||
| size_t n = 0; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_8x12x4::kern_4x12(packA, cur_packB, K, workspace, 12, | |||
| is_first_k, | |||
| std::min<size_t>(M - m, 4)); | |||
| matmul_8x12x4::kern_4x12( | |||
| packA, cur_packB, K, workspace, 12, is_first_k, | |||
| std::min<size_t>(M - m, 4)); | |||
| #define cb(m, n) \ | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 12, m, n>::postprocess( \ | |||
| bias, workspace, output, LDC, op); | |||
| @@ -97,14 +98,13 @@ struct KernCaller<bmode, Op, 8, 12> { | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_8x12x4::kern_4x4(packA, cur_packB, K, workspace, 4, | |||
| is_first_k, std::min<size_t>(M - m, 4), | |||
| std::min<size_t>(N - n, 4)); | |||
| matmul_8x12x4::kern_4x4( | |||
| packA, cur_packB, K, workspace, 4, is_first_k, | |||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
| #define cb(m, n) \ | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \ | |||
| bias, workspace, output, LDC, op); | |||
| DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||
| std::min<size_t>(N - n, 4)); | |||
| DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
| #undef cb | |||
| output += 4; | |||
| @@ -122,9 +122,10 @@ struct KernCaller<bmode, Op, 8, 12> { | |||
| template <BiasMode bmode, typename Op> | |||
| struct KernCaller<bmode, Op, 4, 4> { | |||
| static void run(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
| size_t N, size_t K, dt_int8* C, size_t LDC, bool is_first_k, | |||
| Op op, const dt_int32* bias, dt_int32* workspace) { | |||
| static void run( | |||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
| dt_int8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, | |||
| dt_int32* workspace) { | |||
| megdnn_assert(is_first_k); | |||
| constexpr size_t A_INTERLEAVE = 4; | |||
| @@ -140,20 +141,18 @@ struct KernCaller<bmode, Op, 4, 4> { | |||
| size_t n = 0; | |||
| const dt_int8* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4, | |||
| is_first_k); | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, | |||
| 4>::postprocess(bias, workspace, | |||
| output, LDC, op); | |||
| matmul_4x4x16::kern_4x4(packA, cur_packB, K, workspace, 4, is_first_k); | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, 4>::postprocess( | |||
| bias, workspace, output, LDC, op); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K4; | |||
| } | |||
| for (; n < N; n += B_INTERLEAVE) { | |||
| matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, workspace, | |||
| 4, is_first_k, 4, | |||
| std::min<size_t>(N - n, 4)); | |||
| matmul_4x4x16::kern_4x4_remain( | |||
| packA, cur_packB, K, workspace, 4, is_first_k, 4, | |||
| std::min<size_t>(N - n, 4)); | |||
| #define cb(m, n) \ | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, 4, n>::postprocess( \ | |||
| bias, workspace, output, LDC, op); | |||
| @@ -182,8 +181,7 @@ struct KernCaller<bmode, Op, 4, 4> { | |||
| #define cb(m, n) \ | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_int8, 4, 4, m, n>::postprocess( \ | |||
| bias, workspace, output, LDC, op); | |||
| DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||
| std::min<size_t>(N - n, 4)); | |||
| DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
| #undef cb | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K4; | |||
| @@ -200,21 +198,19 @@ struct KernCaller<bmode, Op, 4, 4> { | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity) | |||
| void gemm_s8_4x4_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
| int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool transpose) const { | |||
| void gemm_s8_4x4_nobias_identity::pack_A( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool transpose) const { | |||
| if (transpose) { | |||
| matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
| } else { | |||
| matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| } | |||
| void gemm_s8_4x4_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, | |||
| int ldin, int x0, int xmax, int k0, | |||
| int kmax, bool transpose) const { | |||
| void gemm_s8_4x4_nobias_identity::pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| if (transpose) { | |||
| matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | |||
| } else { | |||
| @@ -229,23 +225,21 @@ size_t gemm_s8_4x4_nobias_identity::get_workspace_size() const { | |||
| #if MGB_ENABLE_DOT | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity) | |||
| void gemm_s8_8x12_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
| int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool transpose) const { | |||
| void gemm_s8_8x12_nobias_identity::pack_A( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool transpose) const { | |||
| MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_A_t); | |||
| MEGDNN_MARK_USED_VAR(matmul_8x12x4::gemm_s8_8x12_pack_B_t); | |||
| if (transpose) { | |||
| matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_8x12x4::gemm_s8_8x12_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
| } else { | |||
| matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| } | |||
| void gemm_s8_8x12_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, | |||
| int ldin, int x0, int xmax, int k0, | |||
| int kmax, bool transpose) const { | |||
| void gemm_s8_8x12_nobias_identity::pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| if (transpose) { | |||
| matmul_8x12x4::gemm_s8_8x12_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | |||
| } else { | |||
| @@ -259,18 +253,17 @@ size_t gemm_s8_8x12_nobias_identity::get_workspace_size() const { | |||
| #endif | |||
| #define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ | |||
| void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ | |||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, \ | |||
| size_t K, dt_int8* C, size_t LDC, bool is_first_k, \ | |||
| const dt_int32* bias, dt_int32* workspace) const { \ | |||
| float scale_A = A_dtype.param<dtype::QuantizedS8>().scale; \ | |||
| float scale_B = B_dtype.param<dtype::QuantizedS8>().scale; \ | |||
| float scale_C = C_dtype.param<dtype::QuantizedS8>().scale; \ | |||
| DEFINE_OP(_OP); \ | |||
| impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ | |||
| packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ | |||
| workspace); \ | |||
| #define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ | |||
| void gemm_s8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ | |||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, \ | |||
| dt_int8* C, size_t LDC, bool is_first_k, const dt_int32* bias, \ | |||
| dt_int32* workspace) const { \ | |||
| float scale_A = A_dtype.param<dtype::QuantizedS8>().scale; \ | |||
| float scale_B = B_dtype.param<dtype::QuantizedS8>().scale; \ | |||
| float scale_C = C_dtype.param<dtype::QuantizedS8>().scale; \ | |||
| DEFINE_OP(_OP); \ | |||
| impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ | |||
| packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace); \ | |||
| } | |||
| #define DEFINE_OP(_Op) \ | |||
| @@ -286,18 +279,16 @@ KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | |||
| #endif | |||
| #undef DEFINE_OP | |||
| #define DEFINE_OP(_Op) \ | |||
| arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, \ | |||
| scale_A* scale_B, scale_C); | |||
| #define DEFINE_OP(_Op) \ | |||
| arm_common::_Op<dt_qint32, dt_qint8> op( \ | |||
| scale_A* scale_B, scale_A* scale_B, scale_C); | |||
| KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | |||
| KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||
| KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||
| FuseAddHSwishOp) | |||
| KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||
| #if MGB_ENABLE_DOT | |||
| KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | |||
| KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||
| KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||
| FuseAddHSwishOp) | |||
| KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||
| #endif | |||
| #undef DEFINE_OP | |||
| @@ -20,43 +20,42 @@ namespace matmul { | |||
| * | |||
| * \name gemm_<type>_<block>_biasmode_nolinemode | |||
| */ | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 4, 4, 16, | |||
| false, true, | |||
| gemm_s8_4x4_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( | |||
| dt_int8, dt_int8, dt_int32, 4, 4, 16, false, true, gemm_s8_4x4_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_relu, | |||
| gemm_s8_4x4_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_s8_4x4_nobias_relu, gemm_s8_4x4_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_nobias_hswish, | |||
| gemm_s8_4x4_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_s8_4x4_nobias_hswish, gemm_s8_4x4_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_identity, | |||
| gemm_s8_4x4_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_s8_4x4_bias_channel_identity, gemm_s8_4x4_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_relu, | |||
| gemm_s8_4x4_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_s8_4x4_bias_channel_relu, gemm_s8_4x4_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_hswish, | |||
| gemm_s8_4x4_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_s8_4x4_bias_channel_hswish, gemm_s8_4x4_nobias_identity); | |||
| #if MGB_ENABLE_DOT | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 8, 12, 4, | |||
| false, true, | |||
| gemm_s8_8x12_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( | |||
| dt_int8, dt_int8, dt_int32, 8, 12, 4, false, true, | |||
| gemm_s8_8x12_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_relu, | |||
| gemm_s8_8x12_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_s8_8x12_nobias_relu, gemm_s8_8x12_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_nobias_hswish, | |||
| gemm_s8_8x12_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_s8_8x12_nobias_hswish, gemm_s8_8x12_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_identity, | |||
| gemm_s8_8x12_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_s8_8x12_bias_channel_identity, gemm_s8_8x12_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_relu, | |||
| gemm_s8_8x12_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_s8_8x12_bias_channel_relu, gemm_s8_8x12_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_hswish, | |||
| gemm_s8_8x12_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_s8_8x12_bias_channel_hswish, gemm_s8_8x12_nobias_identity); | |||
| #endif | |||
| } // namespace matmul | |||
| @@ -13,13 +13,13 @@ | |||
| #include "src/aarch64/conv_bias/int8/algos.h" | |||
| #include "src/aarch64/conv_bias/quint8/algos.h" | |||
| #include "src/naive/handle.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/common/metahelper.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/naive/handle.h" | |||
| #include "src/fallback/convolution/opr_impl.h" | |||
| #include "src/aarch64/conv_bias/fp32/algos.h" | |||
| #include "src/aarch64/conv_bias/fp16/algos.h" | |||
| #include "src/aarch64/conv_bias/fp32/algos.h" | |||
| #include "src/fallback/convolution/opr_impl.h" | |||
| using namespace megdnn; | |||
| using namespace aarch64; | |||
| @@ -56,12 +56,10 @@ public: | |||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& direct_algos() const { | |||
| return m_direct_algos; | |||
| } | |||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos() | |||
| const { | |||
| const SmallVector<fallback::ConvBiasImpl::AlgoBase*>& matmul_algos() const { | |||
| return m_matmul_algos; | |||
| } | |||
| const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; } | |||
| }; | |||
| const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||
| @@ -71,15 +69,16 @@ const ConvBiasImpl::AlgoPack& ConvBiasImpl::algo_pack() { | |||
| MEGDNN_FB_DEF_GET_ALGO_FROM_DESC(ConvBiasImpl) | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> | |||
| ConvBiasImpl::get_all_packed_algo() { | |||
| SmallVector<fallback::ConvBiasImpl::AlgoBase*> ConvBiasImpl::get_all_packed_algo() { | |||
| auto&& algos = arm_common::ConvBiasImpl::get_all_packed_algo(); | |||
| algos.insert(algos.begin(), algo_pack().direct_algos().begin(), | |||
| algo_pack().direct_algos().end()); | |||
| algos.insert( | |||
| algos.begin(), algo_pack().direct_algos().begin(), | |||
| algo_pack().direct_algos().end()); | |||
| //! We put the matmul algos at the beginning, because matmul gets priority when | |||
| //! prefer returns true. See | |||
| algos.insert(algos.begin(), algo_pack().matmul_algos().begin(), | |||
| algo_pack().matmul_algos().end()); | |||
| algos.insert( | |||
| algos.begin(), algo_pack().matmul_algos().begin(), | |||
| algo_pack().matmul_algos().end()); | |||
| return std::move(algos); | |||
| } | |||
| @@ -9,8 +9,8 @@ | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #pragma once | |||
| #include "src/common/utils.h" | |||
| #include "src/arm_common/conv_bias/opr_impl.h" | |||
| #include "src/common/utils.h" | |||
| namespace megdnn { | |||
| namespace aarch64 { | |||
| @@ -70,9 +70,9 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( | |||
| size_t N = OH * OW; | |||
| #if MGB_ENABLE_DOT | |||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
| _bias_midout_enum, _nonline, \ | |||
| _nonline_midout_enum) \ | |||
| #define DISPATCH_GEMM_STRATEGY( \ | |||
| _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||
| _nonline_midout_enum) \ | |||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
| part2 = megdnn::matmul::GemmInterleaved< \ | |||
| @@ -86,11 +86,12 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( | |||
| DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0); | |||
| } | |||
| #else | |||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
| _bias_midout_enum, _nonline, \ | |||
| _nonline_midout_enum) \ | |||
| MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \ | |||
| _bias_midout_enum, _nonline_midout_enum) { \ | |||
| #define DISPATCH_GEMM_STRATEGY( \ | |||
| _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||
| _nonline_midout_enum) \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_aarch64_conv_bias_quint8_gemm, 0, _gemm_midout_enum, \ | |||
| _bias_midout_enum, _nonline_midout_enum) { \ | |||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
| part2 = megdnn::matmul::GemmInterleaved< \ | |||
| @@ -106,8 +107,8 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( | |||
| return {nullptr, {part0, part1, part2}}; | |||
| } | |||
| void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, | |||
| const NCBKernIndex& ncb_index) { | |||
| void ConvBiasImpl::AlgoQU8MatrixMul::kimpl( | |||
| const NCBKernParam& param, const NCBKernIndex& ncb_index) { | |||
| auto is_xcorr = !param.filter_meta.should_flip; | |||
| UNPACK_CONV_NCB_KERN_SIZES(param); | |||
| auto bundle = get_bundle(param); | |||
| @@ -160,29 +161,28 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, | |||
| img2col<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW); | |||
| } else { | |||
| if (is_xcorr) | |||
| img2col_stride<true>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||
| FW, SH, SW); | |||
| img2col_stride<true>( | |||
| src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); | |||
| else | |||
| img2col_stride<false>(src2, B, OC, OH, OW, IC, IH2, IW2, FH, | |||
| FW, SH, SW); | |||
| img2col_stride<false>( | |||
| src2, B, OC, OH, OW, IC, IH2, IW2, FH, FW, SH, SW); | |||
| } | |||
| } | |||
| { | |||
| Workspace workspace(static_cast<dt_byte*>(bundle.get(2)), | |||
| bundle.get_size(2)); | |||
| Workspace workspace( | |||
| static_cast<dt_byte*>(bundle.get(2)), bundle.get_size(2)); | |||
| size_t M = OC; | |||
| size_t K = IC * FH * FW; | |||
| size_t N = OH * OW; | |||
| #if MGB_ENABLE_DOT | |||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
| _bias_midout_enum, _nonline, \ | |||
| _nonline_midout_enum) \ | |||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
| megdnn::matmul::GemmInterleaved< \ | |||
| matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||
| #define DISPATCH_GEMM_STRATEGY( \ | |||
| _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||
| _nonline_midout_enum) \ | |||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
| megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||
| gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); | |||
| if (cpuinfo_has_arm_neon_dot()) { | |||
| @@ -191,19 +191,18 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, | |||
| DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) | |||
| } | |||
| #else | |||
| #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ | |||
| _bias_midout_enum, _nonline, \ | |||
| _nonline_midout_enum) \ | |||
| MIDOUT_BEGIN(megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \ | |||
| _bias_midout_enum, _nonline_midout_enum) { \ | |||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
| megdnn::matmul::GemmInterleaved< \ | |||
| matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||
| gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, \ | |||
| bias); \ | |||
| } \ | |||
| #define DISPATCH_GEMM_STRATEGY( \ | |||
| _gemm, _gemm_midout_enum, _bias, _bias_midout_enum, _nonline, \ | |||
| _nonline_midout_enum) \ | |||
| MIDOUT_BEGIN( \ | |||
| megdnn_aarch64_conv_bias_quint8_gemm, 1, _gemm_midout_enum, \ | |||
| _bias_midout_enum, _nonline_midout_enum) { \ | |||
| matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ | |||
| M, N, K, param.filter_type, param.src_type, param.dst_type); \ | |||
| megdnn::matmul::GemmInterleaved<matmul::gemm_##_gemm##_##_bias##_##_nonline> \ | |||
| gemm_interleaved(M, N, K, false, false, strategy); \ | |||
| gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); \ | |||
| } \ | |||
| MIDOUT_END() | |||
| DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) | |||
| @@ -12,8 +12,8 @@ | |||
| #pragma once | |||
| #include "src/aarch64/conv_bias/opr_impl.h" | |||
| #include "src/fallback/conv_bias/opr_impl.h" | |||
| #include "src/common/opr_delegate.h" | |||
| #include "src/fallback/conv_bias/opr_impl.h" | |||
| namespace megdnn { | |||
| namespace aarch64 { | |||
| @@ -25,18 +25,16 @@ class ConvBiasImpl::AlgoQU8MatrixMul final : public AlgoBase { | |||
| static void kimpl(const NCBKernParam& param, const NCBKernIndex&); | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "QU8MATMUL"; } | |||
| bool usable(const NCBKernSizeParam& param, | |||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||
| bool usable( | |||
| const NCBKernSizeParam& param, | |||
| AlgoSelectionStrategy algo_selection_strategy) const override; | |||
| size_t get_workspace(const NCBKernSizeParam& param) const override { | |||
| return get_bundle(param).total_size_in_bytes(); | |||
| } | |||
| SmallVector<NCBKern> dispatch_kerns( | |||
| const NCBKernSizeParam& param) const override { | |||
| SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam& param) const override { | |||
| size_t group = param.filter_meta.group; | |||
| return {{kimpl, {group, 1_z, 1_z}}}; | |||
| } | |||
| @@ -14,8 +14,8 @@ | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/conv_bias/common.h" | |||
| #include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h" | |||
| #include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" | |||
| #include "src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h" | |||
| #include "src/arm_common/conv_bias/matmul_postprocess.h" | |||
| using namespace megdnn; | |||
| @@ -29,10 +29,10 @@ struct KernCaller; | |||
| #if MGB_ENABLE_DOT | |||
| template <BiasMode bmode, typename Op> | |||
| struct KernCaller<bmode, Op, 8, 8, true> { | |||
| static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, | |||
| size_t N, size_t K, dt_uint8* C, size_t LDC, | |||
| bool is_first_k, Op op, const dt_int32* bias, | |||
| dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||
| static void run( | |||
| const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K, | |||
| dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, | |||
| dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||
| megdnn_assert(is_first_k); | |||
| constexpr size_t A_INTERLEAVE = 8; | |||
| constexpr size_t B_INTERLEAVE = 8; | |||
| @@ -50,20 +50,19 @@ struct KernCaller<bmode, Op, 8, 8, true> { | |||
| size_t n = 0; | |||
| const dt_uint8* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_8x8x4::kern_8x8(packA, cur_packB, K, workspace, 8, | |||
| is_first_k, zp_A, zp_B, zAB); | |||
| matmul_8x8x4::kern_8x8( | |||
| packA, cur_packB, K, workspace, 8, is_first_k, zp_A, zp_B, zAB); | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, | |||
| 8>::postprocess(bias, workspace, | |||
| output, LDC, op); | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, 8>:: | |||
| postprocess(bias, workspace, output, LDC, op); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K8; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_8x8x4::kern_8x4(packA, cur_packB, K, workspace, 4, | |||
| is_first_k, std::min<size_t>(N - n, 4), | |||
| zp_A, zp_B, zAB); | |||
| matmul_8x8x4::kern_8x4( | |||
| packA, cur_packB, K, workspace, 4, is_first_k, | |||
| std::min<size_t>(N - n, 4), zp_A, zp_B, zAB); | |||
| #define cb(m, n) \ | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \ | |||
| bias, workspace, output, LDC, op); | |||
| @@ -84,9 +83,9 @@ struct KernCaller<bmode, Op, 8, 8, true> { | |||
| const dt_uint8* cur_packB = packB; | |||
| size_t n = 0; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_8x8x4::kern_4x8(packA, cur_packB, K, workspace, 8, | |||
| is_first_k, std::min<size_t>(M - m, 4), | |||
| zp_A, zp_B, zAB); | |||
| matmul_8x8x4::kern_4x8( | |||
| packA, cur_packB, K, workspace, 8, is_first_k, | |||
| std::min<size_t>(M - m, 4), zp_A, zp_B, zAB); | |||
| #define cb(m, n) \ | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \ | |||
| bias, workspace, output, LDC, op); | |||
| @@ -98,15 +97,14 @@ struct KernCaller<bmode, Op, 8, 8, true> { | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_8x8x4::kern_4x4(packA, cur_packB, K, workspace, 4, | |||
| is_first_k, std::min<size_t>(M - m, 4), | |||
| std::min<size_t>(N - n, 4), zp_A, zp_B, | |||
| zAB); | |||
| matmul_8x8x4::kern_4x4( | |||
| packA, cur_packB, K, workspace, 4, is_first_k, | |||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4), zp_A, | |||
| zp_B, zAB); | |||
| #define cb(m, n) \ | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \ | |||
| bias, workspace, output, LDC, op); | |||
| DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||
| std::min<size_t>(N - n, 4)); | |||
| DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
| #undef cb | |||
| output += 4; | |||
| @@ -124,10 +122,10 @@ struct KernCaller<bmode, Op, 8, 8, true> { | |||
| template <BiasMode bmode, typename Op> | |||
| struct KernCaller<bmode, Op, 8, 8, false> { | |||
| static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, | |||
| size_t N, size_t K, dt_uint8* C, size_t LDC, | |||
| bool is_first_k, Op op, const dt_int32* bias, | |||
| dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||
| static void run( | |||
| const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K, | |||
| dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, | |||
| dt_int32* workspace, uint8_t zp_A, uint8_t zp_B) { | |||
| megdnn_assert(is_first_k); | |||
| constexpr size_t A_INTERLEAVE = 8; | |||
| @@ -144,27 +142,25 @@ struct KernCaller<bmode, Op, 8, 8, false> { | |||
| size_t n = 0; | |||
| const dt_uint8* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_8x8x8::kern_8x8(packA, cur_packB, K, workspace, 8, | |||
| is_first_k, zp_A, zp_B); | |||
| matmul_8x8x8::kern_8x8( | |||
| packA, cur_packB, K, workspace, 8, is_first_k, zp_A, zp_B); | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, | |||
| 8>::postprocess(bias, workspace, | |||
| output, LDC, op); | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 8, 8, 8>:: | |||
| postprocess(bias, workspace, output, LDC, op); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K8; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_8x8x8::kern_8x4(packA, cur_packB, K, workspace, 4, | |||
| is_first_k, std::min<size_t>(N - n, 4), | |||
| zp_A, zp_B); | |||
| matmul_8x8x8::kern_8x4( | |||
| packA, cur_packB, K, workspace, 4, is_first_k, | |||
| std::min<size_t>(N - n, 4), zp_A, zp_B); | |||
| #define cb(m, n) \ | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 8, 4, 8, n>::postprocess( \ | |||
| bias, workspace, output, LDC, op); | |||
| DISPATCH_N(cb, 8, std::min<size_t>(N - n, 4)); | |||
| #undef cb | |||
| output += 4; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -179,9 +175,9 @@ struct KernCaller<bmode, Op, 8, 8, false> { | |||
| const dt_uint8* cur_packB = packB; | |||
| size_t n = 0; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_8x8x8::kern_4x8(packA, cur_packB, K, workspace, 8, | |||
| is_first_k, std::min<size_t>(M - m, 4), | |||
| zp_A, zp_B); | |||
| matmul_8x8x8::kern_4x8( | |||
| packA, cur_packB, K, workspace, 8, is_first_k, | |||
| std::min<size_t>(M - m, 4), zp_A, zp_B); | |||
| #define cb(m, n) \ | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 8, m, n>::postprocess( \ | |||
| bias, workspace, output, LDC, op); | |||
| @@ -193,17 +189,16 @@ struct KernCaller<bmode, Op, 8, 8, false> { | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_8x8x8::kern_4x4(packA, cur_packB, K, workspace, 4, | |||
| is_first_k, std::min<size_t>(M - m, 4), | |||
| std::min<size_t>(N - n, 4), zp_A, zp_B); | |||
| matmul_8x8x8::kern_4x4( | |||
| packA, cur_packB, K, workspace, 4, is_first_k, | |||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4), zp_A, | |||
| zp_B); | |||
| #define cb(m, n) \ | |||
| arm_common::ConvBiasMatmul<bmode, Op, dt_uint8, 4, 4, m, n>::postprocess( \ | |||
| bias, workspace, output, LDC, op); | |||
| DISPATCH_M(cb, std::min<size_t>(M - m, 4), | |||
| std::min<size_t>(N - n, 4)); | |||
| DISPATCH_M(cb, std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
| #undef cb | |||
| output += 4; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -219,27 +214,27 @@ struct KernCaller<bmode, Op, 8, 8, false> { | |||
| #if MGB_ENABLE_DOT | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot_nobias_identity) | |||
| void gemm_u8_8x8_dot_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr, | |||
| int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool transpose) const { | |||
| void gemm_u8_8x8_dot_nobias_identity::pack_A( | |||
| uint8_t* outptr, const uint8_t* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool transpose) const { | |||
| if (transpose) { | |||
| matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(outptr, inptr, ldin, y0, | |||
| ymax, k0, kmax); | |||
| matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper( | |||
| outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
| } else { | |||
| matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(outptr, inptr, ldin, | |||
| y0, ymax, k0, kmax); | |||
| matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper( | |||
| outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| } | |||
| void gemm_u8_8x8_dot_nobias_identity::pack_B(uint8_t* out, const uint8_t* in, | |||
| int ldin, int x0, int xmax, int k0, | |||
| int kmax, bool transpose) const { | |||
| void gemm_u8_8x8_dot_nobias_identity::pack_B( | |||
| uint8_t* out, const uint8_t* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| if (transpose) { | |||
| matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0, | |||
| xmax, k0, kmax); | |||
| matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper( | |||
| out, in, ldin, x0, xmax, k0, kmax); | |||
| } else { | |||
| matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper(out, in, ldin, x0, xmax, | |||
| k0, kmax); | |||
| matmul_8x8x4::gemm_u8_8x8_transpose_pack_helper( | |||
| out, in, ldin, x0, xmax, k0, kmax); | |||
| } | |||
| } | |||
| @@ -249,30 +244,27 @@ size_t gemm_u8_8x8_dot_nobias_identity::get_workspace_size() const { | |||
| #endif | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nodot_nobias_identity) | |||
| void gemm_u8_8x8_nodot_nobias_identity::pack_A(dt_uint8* outptr, | |||
| const dt_uint8* inptr, int ldin, | |||
| int y0, int ymax, int k0, int kmax, | |||
| bool transpose) const { | |||
| void gemm_u8_8x8_nodot_nobias_identity::pack_A( | |||
| dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool transpose) const { | |||
| uint8_t zA = A_dtype.param<dtype::Quantized8Asymm>().zero_point; | |||
| if (transpose) { | |||
| matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, | |||
| ymax, k0, kmax, zA); | |||
| matmul_8x8x8::gemm_u8_8x8_transpose_pack_A_n( | |||
| outptr, inptr, ldin, y0, ymax, k0, kmax, zA); | |||
| } else { | |||
| matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
| kmax, zA); | |||
| matmul_8x8x8::gemm_u8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax, zA); | |||
| } | |||
| } | |||
| void gemm_u8_8x8_nodot_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in, | |||
| int ldin, int x0, int xmax, int k0, | |||
| int kmax, bool transpose) const { | |||
| void gemm_u8_8x8_nodot_nobias_identity::pack_B( | |||
| dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point; | |||
| if (transpose) { | |||
| matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, | |||
| k0, kmax, zB); | |||
| matmul_8x8x8::gemm_u8_8x8_transpose_pack_B_n( | |||
| out, in, ldin, x0, xmax, k0, kmax, zB); | |||
| } else { | |||
| matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, | |||
| zB); | |||
| matmul_8x8x8::gemm_u8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax, zB); | |||
| } | |||
| } | |||
| @@ -280,22 +272,21 @@ size_t gemm_u8_8x8_nodot_nobias_identity::get_workspace_size() const { | |||
| return 8 * 8 * sizeof(dt_int32); | |||
| } | |||
| #define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline, \ | |||
| _OP) \ | |||
| void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline:: \ | |||
| kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, \ | |||
| size_t N, size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \ | |||
| const dt_int32* bias, dt_int32* workspace) const { \ | |||
| float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||
| uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||
| float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||
| uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||
| float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||
| uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||
| DEFINE_OP(_OP); \ | |||
| impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \ | |||
| packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ | |||
| workspace, zp_A, zp_B); \ | |||
| #define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline, _OP) \ | |||
| void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline::kern( \ | |||
| const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \ | |||
| size_t K, dt_uint8* C, size_t LDC, bool is_first_k, const dt_int32* bias, \ | |||
| dt_int32* workspace) const { \ | |||
| float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||
| uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||
| float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||
| uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||
| float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale; \ | |||
| uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point; \ | |||
| DEFINE_OP(_OP); \ | |||
| impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \ | |||
| packA, packB, M, N, K, C, LDC, is_first_k, op, bias, workspace, zp_A, \ | |||
| zp_B); \ | |||
| } | |||
| #define DEFINE_OP(_Op) \ | |||
| @@ -311,17 +302,22 @@ KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, relu, ReluOp) | |||
| KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) | |||
| #undef DEFINE_OP | |||
| #define DEFINE_OP(_Op) \ | |||
| arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, \ | |||
| scale_A* scale_B, scale_C, zp_C); | |||
| #define DEFINE_OP(_Op) \ | |||
| arm_common::_Op<dt_qint32, dt_quint8> op( \ | |||
| scale_A* scale_B, scale_A* scale_B, scale_C, zp_C); | |||
| #if MGB_ENABLE_DOT | |||
| KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | |||
| KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||
| KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||
| KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, | |||
| FuseAddReluOp) | |||
| KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||
| FuseAddHSwishOp) | |||
| #endif | |||
| KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) | |||
| KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) | |||
| KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) | |||
| KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, | |||
| AddOp) | |||
| KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, | |||
| FuseAddReluOp) | |||
| KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, | |||
| FuseAddHSwishOp) | |||
| #undef DEFINE_OP | |||
| #undef KERN | |||
| @@ -16,46 +16,44 @@ namespace aarch64 { | |||
| namespace matmul { | |||
| #if MGB_ENABLE_DOT | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 4, | |||
| false, true, | |||
| gemm_u8_8x8_dot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( | |||
| dt_uint8, dt_uint8, dt_int32, 8, 8, 4, false, true, | |||
| gemm_u8_8x8_dot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_relu, | |||
| gemm_u8_8x8_dot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_u8_8x8_dot_nobias_relu, gemm_u8_8x8_dot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_hswish, | |||
| gemm_u8_8x8_dot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_u8_8x8_dot_nobias_hswish, gemm_u8_8x8_dot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_identity, | |||
| gemm_u8_8x8_dot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_u8_8x8_dot_bias_channel_identity, gemm_u8_8x8_dot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_relu, | |||
| gemm_u8_8x8_dot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_hswish, | |||
| gemm_u8_8x8_dot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_u8_8x8_dot_bias_channel_relu, gemm_u8_8x8_dot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_u8_8x8_dot_bias_channel_hswish, gemm_u8_8x8_dot_nobias_identity); | |||
| #endif | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 8, | |||
| false, true, | |||
| gemm_u8_8x8_nodot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_relu, | |||
| gemm_u8_8x8_nodot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK( | |||
| dt_uint8, dt_uint8, dt_int32, 8, 8, 8, false, true, | |||
| gemm_u8_8x8_nodot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_hswish, | |||
| gemm_u8_8x8_nodot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_u8_8x8_nodot_nobias_relu, gemm_u8_8x8_nodot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_identity, | |||
| gemm_u8_8x8_nodot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_u8_8x8_nodot_nobias_hswish, gemm_u8_8x8_nodot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_relu, | |||
| gemm_u8_8x8_nodot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_u8_8x8_nodot_bias_channel_identity, gemm_u8_8x8_nodot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_hswish, | |||
| gemm_u8_8x8_nodot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_u8_8x8_nodot_bias_channel_relu, gemm_u8_8x8_nodot_nobias_identity); | |||
| MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER( | |||
| gemm_u8_8x8_nodot_bias_channel_hswish, gemm_u8_8x8_nodot_nobias_identity); | |||
| } // namespace matmul | |||
| } // namespace aarch64 | |||
| @@ -11,11 +11,11 @@ | |||
| #include "src/common/handle_impl.h" | |||
| #include "src/aarch64/conv_bias/opr_impl.h" | |||
| #include "src/aarch64/handle.h" | |||
| #include "src/aarch64/matrix_mul/opr_impl.h" | |||
| #include "src/aarch64/rotate/opr_impl.h" | |||
| #include "src/aarch64/relayout/opr_impl.h" | |||
| #include "src/aarch64/conv_bias/opr_impl.h" | |||
| #include "src/aarch64/rotate/opr_impl.h" | |||
| #include "src/aarch64/warp_perspective/opr_impl.h" | |||
| namespace megdnn { | |||
| @@ -38,7 +38,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(WarpPerspective) | |||
| MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) | |||
| #pragma GCC diagnostic pop | |||
| } // namespace aarch64 | |||
| } // namespace megdnn | |||
| } // namespace aarch64 | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -14,20 +14,18 @@ | |||
| namespace megdnn { | |||
| namespace aarch64 { | |||
| class HandleImpl: public arm_common::HandleImpl { | |||
| public: | |||
| HandleImpl(megcoreComputingHandle_t computing_handle, | |||
| HandleType type = HandleType::AARCH64): | |||
| arm_common::HandleImpl::HandleImpl(computing_handle, type) | |||
| {} | |||
| class HandleImpl : public arm_common::HandleImpl { | |||
| public: | |||
| HandleImpl( | |||
| megcoreComputingHandle_t computing_handle, | |||
| HandleType type = HandleType::AARCH64) | |||
| : arm_common::HandleImpl::HandleImpl(computing_handle, type) {} | |||
| template <typename Opr> | |||
| std::unique_ptr<Opr> create_operator(); | |||
| template <typename Opr> | |||
| std::unique_ptr<Opr> create_operator(); | |||
| }; | |||
| } // namespace aarch64 | |||
| } // namespace megdnn | |||
| } // namespace aarch64 | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -21,9 +21,7 @@ namespace aarch64 { | |||
| class MatrixMulImpl::AlgoF32K8x12x1 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "AARCH64_F32K8X12X1"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| @@ -35,8 +33,7 @@ public: | |||
| class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -48,9 +45,7 @@ public: | |||
| class MatrixMulImpl::AlgoF32K4x16x1 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "AARCH64_F32K4X16X1"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| @@ -61,9 +56,7 @@ public: | |||
| class MatrixMulImpl::AlgoF32MK4_4x16 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "AARCH64_F32_MK4_4x16"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| @@ -73,8 +66,7 @@ public: | |||
| MEGDNN_DECL_ALGO_TYPE(AARCH64_F32_MK4_4x16) | |||
| }; | |||
| class MatrixMulImpl::AlgoF32Gemv final | |||
| : public arm_common::MatrixMulImpl::AlgoF32Gemv { | |||
| class MatrixMulImpl::AlgoF32Gemv final : public arm_common::MatrixMulImpl::AlgoF32Gemv { | |||
| public: | |||
| AlgoF32Gemv() : arm_common::MatrixMulImpl::AlgoF32Gemv() { | |||
| m_handle_type = Handle::HandleType::AARCH64; | |||
| @@ -85,9 +77,7 @@ public: | |||
| #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| class MatrixMulImpl::AlgoF16K8x24x1 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "AARCH64_F16_K8X24X1"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| @@ -98,9 +88,7 @@ public: | |||
| class MatrixMulImpl::AlgoF16MK8_8x8 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "AARCH64_F16_MK8_8X8"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| @@ -115,12 +103,8 @@ public: | |||
| #if MGB_ENABLE_DOT | |||
| class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| const char* name() const override { | |||
| return "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "AARCH64_INT8X8X32_K8X12X4_DOTPROD"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| @@ -130,12 +114,8 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| const char* name() const override { | |||
| return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| @@ -147,8 +127,7 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -163,9 +142,7 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x32K4x4x16 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "AARCH64_INT8X8X32_K4X4X16"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| bool preferred(const KernSizeParam&) const override; | |||
| @@ -178,9 +155,7 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x32K8x8x8 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "AARCH64_INT8X8X32_K8X8X8"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| bool preferred(const KernSizeParam&) const override; | |||
| @@ -192,9 +167,7 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "AARCH64_INT8X8X16_K8X8X8"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| bool preferred(const KernSizeParam&) const override; | |||
| @@ -207,9 +180,7 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x16K4x4x16 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "AARCH64_INT8X8X16_K4X4X16"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| bool preferred(const KernSizeParam&) const override; | |||
| @@ -222,8 +193,7 @@ public: | |||
| class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -238,12 +208,9 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { | |||
| return "AARCH64_INT8X8X16_MK4_16X12X4"; | |||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "AARCH64_INT8X8X16_MK4_16X12X4"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| bool preferred(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| @@ -257,12 +224,9 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { | |||
| return "AARCH64_INT8X8X16_MK4_K8X8X8"; | |||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "AARCH64_INT8X8X16_MK4_K8X8X8"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| bool preferred(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| @@ -276,8 +240,7 @@ public: | |||
| class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -292,9 +255,7 @@ public: | |||
| class MatrixMulImpl::AlgoInt16x16x32K12x8x1 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "AARCH64_INT16X16X32_K12X8X1"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| bool preferred(const KernSizeParam&) const override; | |||
| @@ -306,9 +267,7 @@ public: | |||
| class MatrixMulImpl::AlgoInt16x16x32MK8_8x8 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "AARCH64_INT16X16X32_MK8_8X8"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| @@ -321,12 +280,8 @@ public: | |||
| #if MGB_ENABLE_DOT | |||
| class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| const char* name() const override { | |||
| return "AARCH64_QUINT8_K8X8X4_DOTPROD"; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "AARCH64_QUINT8_K8X8X4_DOTPROD"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| kern_t get_kern(const KernSizeParam&) const override; | |||
| @@ -336,8 +291,7 @@ public: | |||
| class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE | | |||
| AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| return AlgoAttribute::REPRODUCIBLE | AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||
| } | |||
| const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| @@ -352,9 +306,7 @@ public: | |||
| #endif | |||
| class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase { | |||
| public: | |||
| AlgoAttribute attribute() const override { | |||
| return AlgoAttribute::REPRODUCIBLE; | |||
| } | |||
| AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; } | |||
| const char* name() const override { return "AARCH64_QUINT8_K8X8X8"; } | |||
| bool usable(const KernSizeParam&) const override; | |||
| size_t get_workspace(const KernSizeParam&) const override; | |||
| @@ -16,11 +16,11 @@ namespace megdnn { | |||
| namespace aarch64 { | |||
| namespace matmul { | |||
| MEGDNN_REG_GEMM_STRATEGY(dt_float16, dt_float16, dt_float16, 8, 24, 1, false, | |||
| true, hgemm_8x24); | |||
| MEGDNN_REG_GEMM_STRATEGY( | |||
| dt_float16, dt_float16, dt_float16, 8, 24, 1, false, true, hgemm_8x24); | |||
| MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_float16, dt_float16, dt_float16, 8, 8, 1, | |||
| false, true, gemm_nopack_f16_8x8); | |||
| MEGDNN_REG_GEMM_STRATEGY_NOPACK( | |||
| dt_float16, dt_float16, dt_float16, 8, 8, 1, false, true, gemm_nopack_f16_8x8); | |||
| } // namespace matmul | |||
| } // namespace aarch64 | |||
| @@ -9,8 +9,8 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "src/aarch64/matrix_mul/fp16/strategy.h" | |||
| #include "src/aarch64/matrix_mul/asm/common.h" | |||
| #include "src/aarch64/matrix_mul/fp16/strategy.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/utils.h" | |||
| @@ -21,8 +21,9 @@ using namespace aarch64::matmul; | |||
| namespace { | |||
| void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
| dt_float16* output) { | |||
| void kern_8x1( | |||
| const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
| dt_float16* output) { | |||
| LDB *= sizeof(dt_float16); | |||
| asm volatile( | |||
| ".arch armv8.2-a+fp16\n" | |||
| @@ -86,9 +87,8 @@ void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [output] "+r"(output), [LDB] "+r"(LDB) | |||
| : | |||
| : "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", | |||
| "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", | |||
| "memory"); | |||
| : "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", | |||
| "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); | |||
| } | |||
| // Overview of register layout: | |||
| @@ -115,8 +115,9 @@ void kern_8x1(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
| // |v23[0-7]| |v27[0-7]| | |||
| // +--------+ +--------+ | |||
| // Accumulator | |||
| void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
| dt_float16* output) { | |||
| void kern_8x4( | |||
| const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
| dt_float16* output) { | |||
| //! LDB means number of elements in one block in B. we will read 24 numbers | |||
| //! first. so minus 24 * 2 bytes here. | |||
| LDB = (LDB - 24) * sizeof(dt_float16); | |||
| @@ -263,8 +264,8 @@ void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [output] "+r"(output), [LDB] "+r"(LDB) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", | |||
| "v22", "v23", "v24", "v25", "v26", "v27", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||
| "v23", "v24", "v25", "v26", "v27", "cc", "memory"); | |||
| } | |||
| // Overview of register layout: | |||
| @@ -295,8 +296,9 @@ void kern_8x4(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
| // | v7[0-7]| |v31[0-7]| | |||
| // +--------+ +--------+ | |||
| // Accumulator | |||
| void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
| dt_float16* output) { | |||
| void kern_8x8( | |||
| const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
| dt_float16* output) { | |||
| //! As each load 128 number from B, but the pos add 112 * 2, so we minus 112 | |||
| //! here. | |||
| LDB = (LDB - 32) * sizeof(dt_float16); | |||
| @@ -467,20 +469,19 @@ void kern_8x8(const dt_float16* a_ptr, const dt_float16* b_ptr, int LDB, int K, | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [output] "+r"(output), [LDB] "+r"(LDB) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27", | |||
| "v28", "v29", "v30", "v31", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
| "v12", "v13", "v14", "v15", "v24", "v25", "v26", "v27", "v28", "v29", | |||
| "v30", "v31", "cc", "memory"); | |||
| } | |||
| } // anonymous namespace | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_f16_8x8); | |||
| void gemm_nopack_f16_8x8::kern(const dt_float16* A, size_t LDA, | |||
| const dt_float16* B, size_t LDB, dt_float16* C, | |||
| size_t LDC, size_t M, size_t K, size_t N, | |||
| const dt_float16*, void*, bool trA, | |||
| bool trB) const { | |||
| void gemm_nopack_f16_8x8::kern( | |||
| const dt_float16* A, size_t LDA, const dt_float16* B, size_t LDB, dt_float16* C, | |||
| size_t LDC, size_t M, size_t K, size_t N, const dt_float16*, void*, bool trA, | |||
| bool trB) const { | |||
| constexpr static size_t MB = 8; | |||
| constexpr static size_t KB = 8; | |||
| constexpr static size_t NB = 8; | |||
| @@ -17,21 +17,23 @@ | |||
| namespace megdnn { | |||
| namespace aarch64 { | |||
| MEGDNN_NOINLINE void sgemm_packA_n(const float* A, float* Apacked, size_t M, | |||
| size_t K, size_t LDA, const float* alpha); | |||
| MEGDNN_NOINLINE void sgemm_packA_n( | |||
| const float* A, float* Apacked, size_t M, size_t K, size_t LDA, | |||
| const float* alpha); | |||
| MEGDNN_NOINLINE void sgemm_packA_t(const float* A, float* Apacked, size_t M, | |||
| size_t K, size_t LDA, const float* alpha); | |||
| MEGDNN_NOINLINE void sgemm_packA_t( | |||
| const float* A, float* Apacked, size_t M, size_t K, size_t LDA, | |||
| const float* alpha); | |||
| MEGDNN_NOINLINE void sgemm_packB_n(const float* B, float* Bpacked, size_t K, | |||
| size_t N, size_t LDB); | |||
| MEGDNN_NOINLINE void sgemm_packB_n( | |||
| const float* B, float* Bpacked, size_t K, size_t N, size_t LDB); | |||
| MEGDNN_NOINLINE void sgemm_packB_t(const float* B, float* Bpacked, size_t K, | |||
| size_t N, size_t LDB); | |||
| MEGDNN_NOINLINE void sgemm_packB_t( | |||
| const float* B, float* Bpacked, size_t K, size_t N, size_t LDB); | |||
| MEGDNN_NOINLINE void sgemm_kernel12x8(const float* A, const float* B, float* C, | |||
| size_t LDC, size_t M, size_t N, size_t K, | |||
| int type, const float* beta); | |||
| MEGDNN_NOINLINE void sgemm_kernel12x8( | |||
| const float* A, const float* B, float* C, size_t LDC, size_t M, size_t N, | |||
| size_t K, int type, const float* beta); | |||
| } // namespace aarch64 | |||
| } // namespace megdnn | |||
| @@ -12,7 +12,6 @@ | |||
| #include "src/aarch64/matrix_mul/asm/common.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| namespace megdnn { | |||
| namespace aarch64 { | |||
| namespace matmul_general_4x16 { | |||
| @@ -39,8 +38,9 @@ namespace matmul_general_4x16 { | |||
| // +--+ - - - - +--------+--------+--------+--------+ | |||
| // | |||
| // Accumulator | |||
| void kern_4x16(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k, int m_remain) { | |||
| void kern_4x16( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int m_remain) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| int oddk = (K & 1); | |||
| @@ -224,14 +224,14 @@ void kern_4x16(const float* packA, const float* packB, int K, | |||
| "6:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [m_remain] "+r"(m_remain), [outptr] "+r"(outptr) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| "v20", "v21", "v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9", | |||
| "x10", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
| "v22", "v23", "v24", "v25", "x1", "x2", "x3", "x9", "x10", "cc", | |||
| "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -263,8 +263,9 @@ void kern_4x16(const float* packA, const float* packB, int K, | |||
| // +--+--+ - - - - +--------+ | |||
| // | |||
| // Accumulator | |||
| void kern_4x4(const float* packA, const float* packB, int K, float* output, | |||
| int LDC, bool is_first_k, int m_remain, int n_remain) { | |||
| void kern_4x4( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int m_remain, int n_remain) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| int oddk = (K & 1); | |||
| @@ -330,99 +331,100 @@ void kern_4x4(const float* packA, const float* packB, int K, float* output, | |||
| STORE_LINE("6", "2") \ | |||
| STORE_LINE("7", "3") \ | |||
| "105:\n" | |||
| // clang-format on | |||
| asm volatile( | |||
| // load accumulator C | |||
| "add x1, x0, %x[LDC]\n" | |||
| "add x2, x1, %x[LDC]\n" | |||
| "add x3, x2, %x[LDC]\n" | |||
| "cmp %w[is_first_k], #1\n" | |||
| "beq 1f\n" LOAD_C | |||
| "b 2f\n" | |||
| "1:\n" | |||
| "eor v4.16b, v4.16b, v4.16b\n" | |||
| "eor v5.16b, v5.16b, v5.16b\n" | |||
| "eor v6.16b, v6.16b, v6.16b\n" | |||
| "eor v7.16b, v7.16b, v7.16b\n" | |||
| "2: \n" | |||
| "ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||
| "ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||
| "cmp %w[K], #0\n" | |||
| "beq 4f\n" | |||
| "3:\n" | |||
| "ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||
| "ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||
| "fmla v4.4s, v2.4s, v0.s[0]\n" | |||
| "fmla v5.4s, v2.4s, v0.s[1]\n" | |||
| "fmla v6.4s, v2.4s, v0.s[2]\n" | |||
| "fmla v7.4s, v2.4s, v0.s[3]\n" | |||
| "ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||
| "ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||
| "fmla v4.4s, v3.4s, v1.s[0]\n" | |||
| "fmla v5.4s, v3.4s, v1.s[1]\n" | |||
| "fmla v6.4s, v3.4s, v1.s[2]\n" | |||
| "fmla v7.4s, v3.4s, v1.s[3]\n" | |||
| "subs %w[K], %w[K], #1\n" | |||
| "bne 3b\n" | |||
| "4:\n" | |||
| "cmp %w[oddk], #1\n" | |||
| "beq 5f\n" | |||
| // Even tail | |||
| "ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||
| "ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||
| "fmla v4.4s, v2.4s, v0.s[0]\n" | |||
| "fmla v5.4s, v2.4s, v0.s[1]\n" | |||
| "fmla v6.4s, v2.4s, v0.s[2]\n" | |||
| "fmla v7.4s, v2.4s, v0.s[3]\n" | |||
| "fmla v4.4s, v3.4s, v1.s[0]\n" | |||
| "fmla v5.4s, v3.4s, v1.s[1]\n" | |||
| "fmla v6.4s, v3.4s, v1.s[2]\n" | |||
| "fmla v7.4s, v3.4s, v1.s[3]\n" | |||
| "b 6f\n" | |||
| // odd tail | |||
| "5:\n" | |||
| "fmla v4.4s, v2.4s, v0.s[0]\n" | |||
| "fmla v5.4s, v2.4s, v0.s[1]\n" | |||
| "fmla v6.4s, v2.4s, v0.s[2]\n" | |||
| "fmla v7.4s, v2.4s, v0.s[3]\n" | |||
| "6:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [oddk] "+r"(oddk), [m_remain] "+r"(m_remain), | |||
| [n_remain] "+r"(n_remain), [outptr] "+r"(outptr) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1", | |||
| "x2", "x3", "x10", "cc", "memory"); | |||
| // clang-format on | |||
| asm volatile( | |||
| // load accumulator C | |||
| "add x1, x0, %x[LDC]\n" | |||
| "add x2, x1, %x[LDC]\n" | |||
| "add x3, x2, %x[LDC]\n" | |||
| "cmp %w[is_first_k], #1\n" | |||
| "beq 1f\n" LOAD_C | |||
| "b 2f\n" | |||
| "1:\n" | |||
| "eor v4.16b, v4.16b, v4.16b\n" | |||
| "eor v5.16b, v5.16b, v5.16b\n" | |||
| "eor v6.16b, v6.16b, v6.16b\n" | |||
| "eor v7.16b, v7.16b, v7.16b\n" | |||
| "2: \n" | |||
| "ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||
| "ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||
| "cmp %w[K], #0\n" | |||
| "beq 4f\n" | |||
| "3:\n" | |||
| "ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||
| "ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||
| "fmla v4.4s, v2.4s, v0.s[0]\n" | |||
| "fmla v5.4s, v2.4s, v0.s[1]\n" | |||
| "fmla v6.4s, v2.4s, v0.s[2]\n" | |||
| "fmla v7.4s, v2.4s, v0.s[3]\n" | |||
| "ld1 {v0.4s}, [%[a_ptr]], 16\n" | |||
| "ld1 {v2.4s}, [%[b_ptr]], 16\n" | |||
| "fmla v4.4s, v3.4s, v1.s[0]\n" | |||
| "fmla v5.4s, v3.4s, v1.s[1]\n" | |||
| "fmla v6.4s, v3.4s, v1.s[2]\n" | |||
| "fmla v7.4s, v3.4s, v1.s[3]\n" | |||
| "subs %w[K], %w[K], #1\n" | |||
| "bne 3b\n" | |||
| "4:\n" | |||
| "cmp %w[oddk], #1\n" | |||
| "beq 5f\n" | |||
| // Even tail | |||
| "ld1 {v1.4s}, [%[a_ptr]], 16\n" | |||
| "ld1 {v3.4s}, [%[b_ptr]], 16\n" | |||
| "fmla v4.4s, v2.4s, v0.s[0]\n" | |||
| "fmla v5.4s, v2.4s, v0.s[1]\n" | |||
| "fmla v6.4s, v2.4s, v0.s[2]\n" | |||
| "fmla v7.4s, v2.4s, v0.s[3]\n" | |||
| "fmla v4.4s, v3.4s, v1.s[0]\n" | |||
| "fmla v5.4s, v3.4s, v1.s[1]\n" | |||
| "fmla v6.4s, v3.4s, v1.s[2]\n" | |||
| "fmla v7.4s, v3.4s, v1.s[3]\n" | |||
| "b 6f\n" | |||
| // odd tail | |||
| "5:\n" | |||
| "fmla v4.4s, v2.4s, v0.s[0]\n" | |||
| "fmla v5.4s, v2.4s, v0.s[1]\n" | |||
| "fmla v6.4s, v2.4s, v0.s[2]\n" | |||
| "fmla v7.4s, v2.4s, v0.s[3]\n" | |||
| "6:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain), | |||
| [outptr] "+r"(outptr) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "x1", "x2", "x3", "x10", | |||
| "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| #undef STORE_LINE | |||
| #undef STORE_C | |||
| } | |||
| void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||
| int ymax, int k0, int kmax) { | |||
| void sgemm_4x16_pack_A_n( | |||
| float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| float zerobuff[4]; | |||
| std::memset(zerobuff, 0, sizeof(float) * 4); | |||
| constexpr int PACK_SIZE = 4*4; | |||
| constexpr int PACK_SIZE = 4 * 4; | |||
| int y = y0; | |||
| for (; y + 3 < ymax; y += 4) { | |||
| // printf("main loop pack_a_n %p \n",outptr); | |||
| // printf("main loop pack_a_n %p \n",outptr); | |||
| const float* inptr0 = inptr + y * ldin + k0; | |||
| const float* inptr1 = inptr0 + ldin; | |||
| const float* inptr2 = inptr1 + ldin; | |||
| @@ -459,9 +461,11 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||
| switch ((y + 3) - ymax) { | |||
| /* Everything falls through in here */ | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -478,9 +482,11 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -493,8 +499,8 @@ void sgemm_4x16_pack_A_n(float * outptr, const float * inptr, int ldin, int y0, | |||
| } | |||
| } | |||
| void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, | |||
| int xmax, int k0, int kmax) { | |||
| void sgemm_4x16_pack_A_t( | |||
| float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| int ksize = kmax - k0; | |||
| int ksize4 = (ksize << 2); | |||
| float* outptr_base = out; | |||
| @@ -515,8 +521,7 @@ void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, | |||
| auto outptr = outptr_base; | |||
| for (; x + 4 <= xmax; x += 4) { | |||
| auto outptr_interleave = outptr; | |||
| interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||
| outptr_interleave); | |||
| interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||
| outptr += ksize4; | |||
| } | |||
| @@ -546,8 +551,8 @@ void sgemm_4x16_pack_A_t(float* out, const float* in, int ldin, int x0, | |||
| } | |||
| } | |||
| void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, | |||
| int x0, int xmax, int k0, int kmax) { | |||
| void sgemm_4x16_pack_B_n( | |||
| float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| int ksize = kmax - k0; | |||
| int ksize16 = ksize * 16; | |||
| int ksize4 = (ksize << 2); | |||
| @@ -570,15 +575,13 @@ void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, | |||
| auto outptr = outptr_base; | |||
| for (; x + 16 <= xmax; x += 16) { | |||
| auto outptr_interleave = outptr; | |||
| interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3, | |||
| outptr_interleave); | |||
| interleave_4x16_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||
| outptr += ksize16; | |||
| } | |||
| outptr = outptr_base4; | |||
| for (; x + 4 <= xmax; x += 4) { | |||
| auto outptr_interleave = outptr; | |||
| interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||
| outptr_interleave); | |||
| interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||
| outptr += ksize4; | |||
| } | |||
| @@ -616,8 +619,8 @@ void sgemm_4x16_pack_B_n(float* out, const float* in, int ldin, | |||
| } | |||
| } | |||
| void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||
| int y0, int ymax, int k0, int kmax) { | |||
| void sgemm_4x16_pack_B_t( | |||
| float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) { | |||
| float* outptr = out; | |||
| const float* inptr = in; | |||
| float zerobuff[4]; | |||
| @@ -642,8 +645,7 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||
| int x = (kmax - k0); | |||
| for (; x > 3; x -= 4) { | |||
| transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, | |||
| 64); | |||
| transpose_4x4_1_s(inptr0, inptr1, inptr2, inptr3, outptr_inner, 64); | |||
| outptr_inner += 64; | |||
| } | |||
| for (; x > 0; x--) { | |||
| @@ -676,9 +678,11 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||
| switch ((y + 3) - ymax) { | |||
| /* Everything falls through in here */ | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -696,9 +700,11 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||
| switch ((y + 3) - ymax) { | |||
| /* Everything falls through in here */ | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -711,8 +717,8 @@ void sgemm_4x16_pack_B_t(float* out, const float* in, int ldin, | |||
| } | |||
| } | |||
| } // matmul_general_4x16 | |||
| } // aarch64 | |||
| } // megdnn | |||
| } // namespace matmul_general_4x16 | |||
| } // namespace aarch64 | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -43,8 +43,9 @@ struct matmul_general_8x12 { | |||
| // +--+ --- - +--------+--------+--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_8x12(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k) { | |||
| static void kern_8x12( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| int oddk = (K & 1); | |||
| @@ -306,14 +307,13 @@ struct matmul_general_8x12 { | |||
| "6:\n" | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr) | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [outptr] "+r"(outptr) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
| "v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", | |||
| "x6", "x7", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||
| "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
| "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -348,9 +348,9 @@ struct matmul_general_8x12 { | |||
| // +--+ --- - +--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_8x4(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k, | |||
| int n_remain) { | |||
| static void kern_8x4( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int n_remain) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| int oddk = (K & 1); | |||
| @@ -520,13 +520,12 @@ struct matmul_general_8x12 { | |||
| "6:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
| [n_remain] "+r"(n_remain) | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [outptr] "+r"(outptr), [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", | |||
| "v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||
| "cc", "memory"); | |||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", | |||
| "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "cc", | |||
| "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -557,9 +556,9 @@ struct matmul_general_8x12 { | |||
| // +--+ --- - +--------+--------+--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_4x12(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k, | |||
| int m_remain) { | |||
| static void kern_4x12( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int m_remain) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| int oddk = (K & 1); | |||
| @@ -717,13 +716,12 @@ struct matmul_general_8x12 { | |||
| "6:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
| [m_remain] "+r"(m_remain) | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [outptr] "+r"(outptr), [m_remain] "+r"(m_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
| "v19", "x1", "x2", "x3", "x10", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||
| "x2", "x3", "x10", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -754,9 +752,9 @@ struct matmul_general_8x12 { | |||
| // +--+ --- - +--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_4x4(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k, int m_remain, | |||
| int n_remain) { | |||
| static void kern_4x4( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int m_remain, int n_remain) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| int oddk = (K & 1); | |||
| @@ -895,20 +893,21 @@ struct matmul_general_8x12 { | |||
| "6:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
| [n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [outptr] "+r"(outptr), [n_remain] "+r"(n_remain), | |||
| [m_remain] "+r"(m_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", | |||
| "x3", "x10", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3", | |||
| "x10", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| #undef STORE_LINE | |||
| #undef STORE_C | |||
| } | |||
| static void sgemm_8x12_pack_A_n(float* outptr, const float* inptr, int ldin, | |||
| int y0, int ymax, int k0, int kmax) { | |||
| static void sgemm_8x12_pack_A_n( | |||
| float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| float zerobuff[8]; | |||
| std::memset(zerobuff, 0, sizeof(float) * 8); | |||
| constexpr int PACK_SIZE_32 = 4 * 8; | |||
| @@ -933,8 +932,9 @@ struct matmul_general_8x12 { | |||
| prefetch_2x(inptr7); | |||
| int x = (kmax - k0); | |||
| for (; x > 3; x -= 4) { | |||
| transpose_8x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, | |||
| inptr5, inptr6, inptr7, outptr); | |||
| transpose_8x4_1_s( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr); | |||
| outptr += PACK_SIZE_32; | |||
| } | |||
| for (; x > 0; x--) { | |||
| @@ -1004,8 +1004,8 @@ struct matmul_general_8x12 { | |||
| } | |||
| } | |||
| static void sgemm_8x12_pack_A_t(float* out, const float* in, int ldin, | |||
| int x0, int xmax, int k0, int kmax) { | |||
| static void sgemm_8x12_pack_A_t( | |||
| float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| int ksize = kmax - k0; | |||
| int ksize8 = (ksize << 3); | |||
| int ksize4 = (ksize << 2); | |||
| @@ -1028,20 +1028,17 @@ struct matmul_general_8x12 { | |||
| auto outptr = outptr_base; | |||
| for (; x + 8 <= xmax; x += 8) { | |||
| auto outptr_interleave = outptr; | |||
| interleave_4x8_1_s(inptr, inptr1, inptr2, inptr3, | |||
| outptr_interleave); | |||
| interleave_4x8_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||
| outptr += ksize8; | |||
| } | |||
| outptr = outptr_base4; | |||
| for (; x + 4 <= xmax; x += 4) { | |||
| auto outptr_interleave = outptr; | |||
| interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||
| outptr_interleave); | |||
| interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||
| outptr += ksize4; | |||
| } | |||
| if (x < xmax) { | |||
| interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, | |||
| xmax - x); | |||
| interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); | |||
| } | |||
| outptr_base += 4 * 8; | |||
| outptr_base4 += 4 * 4; | |||
| @@ -1071,8 +1068,8 @@ struct matmul_general_8x12 { | |||
| } | |||
| } | |||
| static void sgemm_8x12_pack_B_n(float* out, const float* in, int ldin, | |||
| int x0, int xmax, int k0, int kmax) { | |||
| static void sgemm_8x12_pack_B_n( | |||
| float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| int ksize = kmax - k0; | |||
| int ksize12 = ksize * 12; | |||
| int ksize4 = (ksize << 2); | |||
| @@ -1095,20 +1092,17 @@ struct matmul_general_8x12 { | |||
| auto outptr = outptr_base; | |||
| for (; x + 12 <= xmax; x += 12) { | |||
| auto outptr_interleave = outptr; | |||
| interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, | |||
| outptr_interleave); | |||
| interleave_4x12_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||
| outptr += ksize12; | |||
| } | |||
| outptr = outptr_base4; | |||
| for (; x + 4 <= xmax; x += 4) { | |||
| auto outptr_interleave = outptr; | |||
| interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, | |||
| outptr_interleave); | |||
| interleave_4x4_1_s(inptr, inptr1, inptr2, inptr3, outptr_interleave); | |||
| outptr += ksize4; | |||
| } | |||
| if (x < xmax) { | |||
| interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, | |||
| xmax - x); | |||
| interleave_4(inptr, inptr1, inptr2, inptr3, outptr, 4, xmax - x); | |||
| } | |||
| outptr_base += 12 * 4; | |||
| outptr_base4 += 4 * 4; | |||
| @@ -1138,8 +1132,8 @@ struct matmul_general_8x12 { | |||
| } | |||
| } | |||
| static void sgemm_8x12_pack_B_t(float* out, const float* in, int ldin, | |||
| int y0, int ymax, int k0, int kmax) { | |||
| static void sgemm_8x12_pack_B_t( | |||
| float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax) { | |||
| float* outptr = out; | |||
| const float* inptr = in; | |||
| float zerobuff[12]; | |||
| @@ -1172,9 +1166,9 @@ struct matmul_general_8x12 { | |||
| prefetch_2x(inptr11); | |||
| int x = (kmax - k0); | |||
| for (; x > 3; x -= 4) { | |||
| transpose_12x4_1_s(inptr0, inptr1, inptr2, inptr3, inptr4, | |||
| inptr5, inptr6, inptr7, inptr8, inptr9, | |||
| inptr10, inptr11, outptr); | |||
| transpose_12x4_1_s( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| inptr8, inptr9, inptr10, inptr11, outptr); | |||
| outptr += 48; | |||
| } | |||
| for (; x > 0; x--) { | |||
| @@ -43,8 +43,9 @@ struct matmul_general_8x12_a53 { | |||
| // +--+ --- - +--------+--------+--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_8x12(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k) { | |||
| static void kern_8x12( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| int oddk = (K & 1); | |||
| @@ -575,15 +576,14 @@ struct matmul_general_8x12_a53 { | |||
| "6:\n" | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr) | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [outptr] "+r"(outptr) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
| "v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", | |||
| "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", | |||
| "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||
| "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
| "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", | |||
| "x11", "x12", "x13", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| } | |||
| @@ -615,9 +615,9 @@ struct matmul_general_8x12_a53 { | |||
| // +--+ --- - +--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_8x4(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k, | |||
| int n_remain) { | |||
| static void kern_8x4( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int n_remain) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| int oddk = (K & 1); | |||
| @@ -856,13 +856,12 @@ struct matmul_general_8x12_a53 { | |||
| "6:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
| [n_remain] "+r"(n_remain) | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [outptr] "+r"(outptr), [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", | |||
| "v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||
| "x8", "x9", "x10", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", | |||
| "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", | |||
| "x10", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -893,9 +892,9 @@ struct matmul_general_8x12_a53 { | |||
| // +--+ --- - +--------+--------+--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_4x12(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k, | |||
| int m_remain) { | |||
| static void kern_4x12( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int m_remain) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| int oddk = (K & 1); | |||
| @@ -1133,14 +1132,12 @@ struct matmul_general_8x12_a53 { | |||
| "6:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
| [m_remain] "+r"(m_remain) | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [outptr] "+r"(outptr), [m_remain] "+r"(m_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
| "v19", "x1", "x2", "x3", "x8", "x9", "x10", "x20", "x21", | |||
| "x22", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||
| "x2", "x3", "x8", "x9", "x10", "x20", "x21", "x22", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -1171,9 +1168,9 @@ struct matmul_general_8x12_a53 { | |||
| // +--+ --- - +--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_4x4(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k, int m_remain, | |||
| int n_remain) { | |||
| static void kern_4x4( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int m_remain, int n_remain) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| int oddk = (K & 1); | |||
| @@ -1312,12 +1309,12 @@ struct matmul_general_8x12_a53 { | |||
| "6:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
| [n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [outptr] "+r"(outptr), [n_remain] "+r"(n_remain), | |||
| [m_remain] "+r"(m_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", | |||
| "x3", "x10", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3", | |||
| "x10", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| #undef STORE_LINE | |||
| @@ -43,8 +43,9 @@ struct matmul_general_8x12_a55 { | |||
| // +--+ --- - +--------+--------+--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_8x12(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k) { | |||
| static void kern_8x12( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| int oddk = (K & 1); | |||
| @@ -525,15 +526,14 @@ struct matmul_general_8x12_a55 { | |||
| "6:\n" | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr) | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [outptr] "+r"(outptr) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
| "v28", "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", | |||
| "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc", | |||
| "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||
| "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
| "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", | |||
| "x11", "x12", "x13", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| } | |||
| @@ -565,9 +565,9 @@ struct matmul_general_8x12_a55 { | |||
| // +--+ --- - +--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_8x4(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k, | |||
| int n_remain) { | |||
| static void kern_8x4( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int n_remain) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| int oddk = (K & 1); | |||
| @@ -742,13 +742,12 @@ struct matmul_general_8x12_a55 { | |||
| "6:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
| [n_remain] "+r"(n_remain) | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [outptr] "+r"(outptr), [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", | |||
| "v23", "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||
| "x10", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "v20", "v23", | |||
| "v26", "v29", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x10", "cc", | |||
| "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -779,9 +778,9 @@ struct matmul_general_8x12_a55 { | |||
| // +--+ --- - +--------+--------+--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_4x12(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k, | |||
| int m_remain) { | |||
| static void kern_4x12( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int m_remain) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| int oddk = (K & 1); | |||
| @@ -972,14 +971,12 @@ struct matmul_general_8x12_a55 { | |||
| "6:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
| [m_remain] "+r"(m_remain) | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [outptr] "+r"(outptr), [m_remain] "+r"(m_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
| "v19", "x1", "x2", "x3", "x10", "x20", "x21", "x22", "cc", | |||
| "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||
| "x2", "x3", "x10", "x20", "x21", "x22", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -1010,9 +1007,9 @@ struct matmul_general_8x12_a55 { | |||
| // +--+ --- - +--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_4x4(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k, int m_remain, | |||
| int n_remain) { | |||
| static void kern_4x4( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int m_remain, int n_remain) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| int oddk = (K & 1); | |||
| @@ -1151,12 +1148,12 @@ struct matmul_general_8x12_a55 { | |||
| "6:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [oddk] "+r"(oddk), [outptr] "+r"(outptr), | |||
| [n_remain] "+r"(n_remain), [m_remain] "+r"(m_remain) | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [outptr] "+r"(outptr), [n_remain] "+r"(n_remain), | |||
| [m_remain] "+r"(m_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", | |||
| "x3", "x10", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v5", "v8", "v11", "v14", "v17", "x1", "x2", "x3", | |||
| "x10", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| #undef STORE_LINE | |||
| @@ -44,8 +44,9 @@ struct matmul_mk4_8x12 { | |||
| // +--+ --- - +--------+--------+--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_8x12(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k) { | |||
| static void kern_8x12( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| float* output0 = output; | |||
| @@ -307,10 +308,10 @@ struct matmul_mk4_8x12 { | |||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [output0] "+r"(output0), [output1] "+r"(output1) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
| "v28", "v29", "v30", "v31", "x1", "x2", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||
| "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
| "v31", "x1", "x2", "cc", "memory"); | |||
| } | |||
| // Overview of register layout: | |||
| @@ -340,9 +341,9 @@ struct matmul_mk4_8x12 { | |||
| // +--+ --- - +--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_8x4(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k, | |||
| int n_remain) { | |||
| static void kern_8x4( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int n_remain) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| float* output0 = output; | |||
| @@ -500,8 +501,8 @@ struct matmul_mk4_8x12 { | |||
| [output0] "+r"(output0), [output1] "+r"(output1), | |||
| [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", | |||
| "v13", "v14", "v15", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
| "v15", "cc", "memory"); | |||
| #undef LOAD_C | |||
| #undef STORE_C | |||
| @@ -531,8 +532,9 @@ struct matmul_mk4_8x12 { | |||
| // | |||
| // Accumulator | |||
| static void kern_4x12(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k) { | |||
| static void kern_4x12( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k) { | |||
| MEGDNN_MARK_USED_VAR(LDC); | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| @@ -669,9 +671,9 @@ struct matmul_mk4_8x12 { | |||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [output0] "+r"(output0) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
| "v19", "x1", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||
| "cc", "memory"); | |||
| } | |||
| // Overview of register layout: | |||
| @@ -697,9 +699,9 @@ struct matmul_mk4_8x12 { | |||
| // +--+ --- - +--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_4x4(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k, | |||
| int n_remain) { | |||
| static void kern_4x4( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int n_remain) { | |||
| MEGDNN_MARK_USED_VAR(LDC); | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| @@ -818,15 +820,15 @@ struct matmul_mk4_8x12 { | |||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [output0] "+r"(output0), [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", | |||
| "memory"); | |||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); | |||
| #undef LOAD_C | |||
| #undef STORE_C | |||
| } | |||
| static void sgemm_8x12_pack_A(float* outptr, const float* inptr, int ldin, | |||
| int y0, int ymax, int k0, int kmax) { | |||
| static void sgemm_8x12_pack_A( | |||
| float* outptr, const float* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| megdnn_assert(y0 % 4 == 0 && ymax % 4 == 0, "M must be time of 4"); | |||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | |||
| constexpr int PACK_SIZE_32 = 4 * 8; | |||
| @@ -855,8 +857,8 @@ struct matmul_mk4_8x12 { | |||
| } | |||
| } | |||
| static void sgemm_8x12_pack_B(float* out, const float* in, int ldin, int x0, | |||
| int xmax, int k0, int kmax) { | |||
| static void sgemm_8x12_pack_B( | |||
| float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | |||
| float tmpbuff[16] = {0.0f}; | |||
| @@ -886,8 +888,7 @@ struct matmul_mk4_8x12 { | |||
| outptr += ksize4; | |||
| } | |||
| if (x < xmax) { | |||
| std::memcpy(tmpbuff, inptr, | |||
| sizeof(float) * (xmax - x) * PACK_C_SIZE); | |||
| std::memcpy(tmpbuff, inptr, sizeof(float) * (xmax - x) * PACK_C_SIZE); | |||
| auto outptr_interleave = outptr; | |||
| const float* tmp_ptr = &tmpbuff[0]; | |||
| transpose_1x4_4_s<float>(tmp_ptr, outptr_interleave); | |||
| @@ -44,8 +44,9 @@ struct matmul_mk4_8x12_a53 { | |||
| // +--+ --- - +--------+--------+--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_8x12(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k) { | |||
| static void kern_8x12( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| float* output0 = output; | |||
| @@ -553,11 +554,11 @@ struct matmul_mk4_8x12_a53 { | |||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [output0] "+r"(output0), [output1] "+r"(output1) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
| "v28", "v29", "v30", "v31", "x1", "x2", "x8", "x9", "x10", | |||
| "x11", "x12", "x13", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||
| "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
| "v31", "x1", "x2", "x8", "x9", "x10", "x11", "x12", "x13", "cc", | |||
| "memory"); | |||
| } | |||
| // Overview of register layout: | |||
| @@ -587,9 +588,9 @@ struct matmul_mk4_8x12_a53 { | |||
| // +--+ --- - +--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_8x4(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k, | |||
| int n_remain) { | |||
| static void kern_8x4( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int n_remain) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| float* output0 = output; | |||
| @@ -831,8 +832,8 @@ struct matmul_mk4_8x12_a53 { | |||
| [output0] "+r"(output0), [output1] "+r"(output1), | |||
| [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", | |||
| "v13", "v14", "v15", "x8", "x9", "x10", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
| "v15", "x8", "x9", "x10", "cc", "memory"); | |||
| #undef LOAD_C | |||
| #undef STORE_C | |||
| @@ -862,8 +863,9 @@ struct matmul_mk4_8x12_a53 { | |||
| // | |||
| // Accumulator | |||
| static void kern_4x12(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k) { | |||
| static void kern_4x12( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k) { | |||
| MEGDNN_MARK_USED_VAR(LDC); | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| @@ -1098,9 +1100,9 @@ struct matmul_mk4_8x12_a53 { | |||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [output0] "+r"(output0) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
| "v19", "x1", "x8", "x9", "x10", "x11", "x12", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||
| "x8", "x9", "x10", "x11", "x12", "cc", "memory"); | |||
| } | |||
| // Overview of register layout: | |||
| @@ -1126,9 +1128,9 @@ struct matmul_mk4_8x12_a53 { | |||
| // +--+ --- - +--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_4x4(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k, | |||
| int n_remain) { | |||
| static void kern_4x4( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int n_remain) { | |||
| MEGDNN_MARK_USED_VAR(LDC); | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| @@ -1246,8 +1248,7 @@ struct matmul_mk4_8x12_a53 { | |||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [output0] "+r"(output0), [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", | |||
| "memory"); | |||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); | |||
| #undef LOAD_C | |||
| #undef STORE_C | |||
| @@ -44,8 +44,9 @@ struct matmul_mk4_8x12_a55 { | |||
| // +--+ --- - +--------+--------+--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_8x12(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k) { | |||
| static void kern_8x12( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| float* output0 = output; | |||
| @@ -519,11 +520,11 @@ struct matmul_mk4_8x12_a55 { | |||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [output0] "+r"(output0), [output1] "+r"(output1) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", | |||
| "v28", "v29", "v30", "v31", "x1", "x2", "x8", "x9", "x10", | |||
| "x11", "x12", "x13", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", | |||
| "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
| "v31", "x1", "x2", "x8", "x9", "x10", "x11", "x12", "x13", "cc", | |||
| "memory"); | |||
| } | |||
| // Overview of register layout: | |||
| @@ -553,9 +554,9 @@ struct matmul_mk4_8x12_a55 { | |||
| // +--+ --- - +--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_8x4(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k, | |||
| int n_remain) { | |||
| static void kern_8x4( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int n_remain) { | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| float* output0 = output; | |||
| @@ -749,8 +750,8 @@ struct matmul_mk4_8x12_a55 { | |||
| [output0] "+r"(output0), [output1] "+r"(output1), | |||
| [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", | |||
| "v13", "v14", "v15", "x8", "x9", "x10", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
| "v15", "x8", "x9", "x10", "cc", "memory"); | |||
| #undef LOAD_C | |||
| #undef STORE_C | |||
| @@ -780,8 +781,9 @@ struct matmul_mk4_8x12_a55 { | |||
| // | |||
| // Accumulator | |||
| static void kern_4x12(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k) { | |||
| static void kern_4x12( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k) { | |||
| MEGDNN_MARK_USED_VAR(LDC); | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| @@ -997,9 +999,9 @@ struct matmul_mk4_8x12_a55 { | |||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [output0] "+r"(output0) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", | |||
| "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
| "v19", "x1", "x8", "x9", "x10", "x11", "x12", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", | |||
| "x8", "x9", "x10", "x11", "x12", "cc", "memory"); | |||
| } | |||
| // Overview of register layout: | |||
| @@ -1025,9 +1027,9 @@ struct matmul_mk4_8x12_a55 { | |||
| // +--+ --- - +--------+ | |||
| // | |||
| // Accumulator | |||
| static void kern_4x4(const float* packA, const float* packB, int K, | |||
| float* output, int LDC, bool is_first_k, | |||
| int n_remain) { | |||
| static void kern_4x4( | |||
| const float* packA, const float* packB, int K, float* output, int LDC, | |||
| bool is_first_k, int n_remain) { | |||
| MEGDNN_MARK_USED_VAR(LDC); | |||
| const float* a_ptr = packA; | |||
| const float* b_ptr = packB; | |||
| @@ -1146,8 +1148,7 @@ struct matmul_mk4_8x12_a55 { | |||
| [is_first_k] "+r"(is_first_k), [oddk] "+r"(oddk), | |||
| [output0] "+r"(output0), [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", | |||
| "memory"); | |||
| : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11", "cc", "memory"); | |||
| #undef LOAD_C | |||
| #undef STORE_C | |||
| @@ -10,6 +10,7 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/aarch64/matrix_mul/fp32/strategy.h" | |||
| #include "src/aarch64/matrix_mul/fp32/kernel_general_4x16.h" | |||
| #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12.h" | |||
| #include "src/aarch64/matrix_mul/fp32/kernel_general_8x12_a53.h" | |||
| @@ -17,44 +18,40 @@ | |||
| #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12.h" | |||
| #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a53.h" | |||
| #include "src/aarch64/matrix_mul/fp32/kernel_mk4_8x12_a55.h" | |||
| #include "src/aarch64/matrix_mul/fp32/strategy.h" | |||
| #include "src/common/utils.h" | |||
| using namespace megdnn; | |||
| using namespace aarch64; | |||
| using namespace aarch64::matmul; | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_4x16); | |||
| void sgemm_4x16::pack_A(float* out, const float* in, int ldin, int y0, int ymax, | |||
| int k0, int kmax, bool transpose_A) const { | |||
| void sgemm_4x16::pack_A( | |||
| float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
| bool transpose_A) const { | |||
| if (transpose_A) { | |||
| matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_general_4x16::sgemm_4x16_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); | |||
| } else { | |||
| matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_general_4x16::sgemm_4x16_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| } | |||
| void sgemm_4x16::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | |||
| int k0, int kmax, bool transpose_B) const { | |||
| void sgemm_4x16::pack_B( | |||
| float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool transpose_B) const { | |||
| if (transpose_B) { | |||
| matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, | |||
| kmax); | |||
| matmul_general_4x16::sgemm_4x16_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | |||
| } else { | |||
| matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, | |||
| kmax); | |||
| matmul_general_4x16::sgemm_4x16_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
| } | |||
| } | |||
| void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, | |||
| size_t N, size_t K, float* C, size_t LDC, bool is_first_k, | |||
| const float*, float*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| A_dtype.enumv() == C_dtype.enumv() && | |||
| A_dtype.enumv() == DTypeEnum::Float32); | |||
| void sgemm_4x16::kern( | |||
| const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||
| size_t LDC, bool is_first_k, const float*, float*) const { | |||
| megdnn_assert( | |||
| A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && | |||
| A_dtype.enumv() == DTypeEnum::Float32); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| @@ -71,9 +68,9 @@ void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, | |||
| size_t n = 0; | |||
| const float* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_general_4x16::kern_4x16(packA, cur_packB, K, output, LDC, | |||
| is_first_k, | |||
| std::min<size_t>(M - m, 4)); | |||
| matmul_general_4x16::kern_4x16( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4)); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K16; | |||
| } | |||
| @@ -92,32 +89,30 @@ void sgemm_4x16::kern(const float* packA, const float* packB, size_t M, | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_8x12); | |||
| void sgemm_8x12::pack_A(float* out, const float* in, int ldin, int y0, int ymax, | |||
| int k0, int kmax, bool transpose_A) const { | |||
| void sgemm_8x12::pack_A( | |||
| float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
| bool transpose_A) const { | |||
| if (transpose_A) { | |||
| matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_general_8x12::sgemm_8x12_pack_A_t(out, in, ldin, y0, ymax, k0, kmax); | |||
| } else { | |||
| matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_general_8x12::sgemm_8x12_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| } | |||
| void sgemm_8x12::pack_B(float* out, const float* in, int ldin, int x0, int xmax, | |||
| int k0, int kmax, bool transpose_B) const { | |||
| void sgemm_8x12::pack_B( | |||
| float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool transpose_B) const { | |||
| if (transpose_B) { | |||
| matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, | |||
| kmax); | |||
| matmul_general_8x12::sgemm_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | |||
| } else { | |||
| matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, | |||
| kmax); | |||
| matmul_general_8x12::sgemm_8x12_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
| } | |||
| } | |||
| template <typename gemm_class> | |||
| static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||
| size_t M, size_t N, size_t K, float* C, | |||
| size_t LDC, bool is_first_k) { | |||
| static inline void sgemm_8x12_helper( | |||
| const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||
| size_t LDC, bool is_first_k) { | |||
| constexpr size_t A_INTERLEAVE = 8; | |||
| constexpr size_t A_INTERLEAVE4 = 4; | |||
| constexpr size_t B_INTERLEAVE = 12; | |||
| @@ -138,8 +133,9 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||
| } | |||
| for (; n < N; n += 4) { | |||
| gemm_class::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| gemm_class::kern_8x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| output += 4; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -150,16 +146,17 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||
| size_t n = 0; | |||
| const float* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| gemm_class::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4)); | |||
| gemm_class::kern_4x12( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4)); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K12; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| gemm_class::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4), | |||
| std::min<size_t>(N - n, 4)); | |||
| gemm_class::kern_4x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
| output += 4; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -167,56 +164,55 @@ static inline void sgemm_8x12_helper(const float* packA, const float* packB, | |||
| } | |||
| } | |||
| void sgemm_8x12::kern(const float* packA, const float* packB, size_t M, | |||
| size_t N, size_t K, float* C, size_t LDC, bool is_first_k, | |||
| const float*, float*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| A_dtype.enumv() == C_dtype.enumv() && | |||
| A_dtype.enumv() == DTypeEnum::Float32); | |||
| void sgemm_8x12::kern( | |||
| const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||
| size_t LDC, bool is_first_k, const float*, float*) const { | |||
| megdnn_assert( | |||
| A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && | |||
| A_dtype.enumv() == DTypeEnum::Float32); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| #if !MGB_ENABLE_CPUINFO | |||
| sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, | |||
| is_first_k); | |||
| sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, is_first_k); | |||
| #else | |||
| auto arch = cpuinfo_get_current_core()->uarch; | |||
| #ifdef __IN_TEE_ENV__ | |||
| arch = cpuinfo_uarch_unknown; | |||
| #endif | |||
| if (arch == cpuinfo_uarch_cortex_a53) { | |||
| sgemm_8x12_helper<matmul_general_8x12_a53>(packA, packB, M, N, K, C, | |||
| LDC, is_first_k); | |||
| sgemm_8x12_helper<matmul_general_8x12_a53>( | |||
| packA, packB, M, N, K, C, LDC, is_first_k); | |||
| } else if (arch == cpuinfo_uarch_cortex_a55) { | |||
| sgemm_8x12_helper<matmul_general_8x12_a55>(packA, packB, M, N, K, C, | |||
| LDC, is_first_k); | |||
| sgemm_8x12_helper<matmul_general_8x12_a55>( | |||
| packA, packB, M, N, K, C, LDC, is_first_k); | |||
| } else { | |||
| sgemm_8x12_helper<matmul_general_8x12>(packA, packB, M, N, K, C, LDC, | |||
| is_first_k); | |||
| sgemm_8x12_helper<matmul_general_8x12>( | |||
| packA, packB, M, N, K, C, LDC, is_first_k); | |||
| } | |||
| #endif | |||
| } | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(sgemm_mk4_8x12); | |||
| void sgemm_mk4_8x12::pack_A(float* out, const float* in, int ldin, int y0, | |||
| int ymax, int k0, int kmax, | |||
| bool transpose_A) const { | |||
| void sgemm_mk4_8x12::pack_A( | |||
| float* out, const float* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
| bool transpose_A) const { | |||
| megdnn_assert(!transpose_A, "mk4 float matmul not support transpose A"); | |||
| matmul_mk4_8x12::sgemm_8x12_pack_A(out, in, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| void sgemm_mk4_8x12::pack_B(float* out, const float* in, int ldin, int x0, | |||
| int xmax, int k0, int kmax, | |||
| bool transpose_B) const { | |||
| void sgemm_mk4_8x12::pack_B( | |||
| float* out, const float* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool transpose_B) const { | |||
| megdnn_assert(!transpose_B, "mk4 float matmul not support transpose B"); | |||
| matmul_mk4_8x12::sgemm_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); | |||
| } | |||
| template <typename gemm_name> | |||
| static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, | |||
| size_t M, size_t N, size_t K, float* C, | |||
| size_t LDC, bool is_first_k) { | |||
| static inline void sgemm_mk4_8x12_helper( | |||
| const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||
| size_t LDC, bool is_first_k) { | |||
| const int K12 = K * 12; | |||
| const int K8 = K * 8; | |||
| const int K4 = K * 4; | |||
| @@ -237,8 +233,9 @@ static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, | |||
| } | |||
| for (; n < N; n += 4) { | |||
| gemm_name::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| gemm_name::kern_8x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| output += 4 * PACK_C_SIZE; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -254,41 +251,41 @@ static inline void sgemm_mk4_8x12_helper(const float* packA, const float* packB, | |||
| cur_packB += K12; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| gemm_name::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| gemm_name::kern_4x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| output += 4 * PACK_C_SIZE; | |||
| cur_packB += K4; | |||
| } | |||
| packA += K4; | |||
| } | |||
| } | |||
| void sgemm_mk4_8x12::kern(const float* packA, const float* packB, size_t M, | |||
| size_t N, size_t K, float* C, size_t LDC, | |||
| bool is_first_k, const float*, float*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| A_dtype.enumv() == C_dtype.enumv() && | |||
| A_dtype.enumv() == DTypeEnum::Float32); | |||
| void sgemm_mk4_8x12::kern( | |||
| const float* packA, const float* packB, size_t M, size_t N, size_t K, float* C, | |||
| size_t LDC, bool is_first_k, const float*, float*) const { | |||
| megdnn_assert( | |||
| A_dtype.enumv() == B_dtype.enumv() && A_dtype.enumv() == C_dtype.enumv() && | |||
| A_dtype.enumv() == DTypeEnum::Float32); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| megdnn_assert(M % 4 == 0 && K % 4 == 0, "M and K must be time of 4"); | |||
| #if !MGB_ENABLE_CPUINFO | |||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, | |||
| is_first_k); | |||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, is_first_k); | |||
| #else | |||
| auto arch = cpuinfo_get_current_core()->uarch; | |||
| #ifdef __IN_TEE_ENV__ | |||
| arch = cpuinfo_uarch_unknown; | |||
| #endif | |||
| if (arch == cpuinfo_uarch_cortex_a53) { | |||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>(packA, packB, M, N, K, C, | |||
| LDC, is_first_k); | |||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12_a53>( | |||
| packA, packB, M, N, K, C, LDC, is_first_k); | |||
| } else if (arch == cpuinfo_uarch_cortex_a55) { | |||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>(packA, packB, M, N, K, C, | |||
| LDC, is_first_k); | |||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12_a55>( | |||
| packA, packB, M, N, K, C, LDC, is_first_k); | |||
| } else { | |||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12>(packA, packB, M, N, K, C, LDC, | |||
| is_first_k); | |||
| sgemm_mk4_8x12_helper<matmul_mk4_8x12>( | |||
| packA, packB, M, N, K, C, LDC, is_first_k); | |||
| } | |||
| #endif | |||
| } | |||
| @@ -15,17 +15,14 @@ | |||
| namespace megdnn { | |||
| namespace aarch64 { | |||
| namespace matmul { | |||
| MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, | |||
| sgemm_8x12); | |||
| MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, true, sgemm_8x12); | |||
| MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true, | |||
| sgemm_4x16); | |||
| MEGDNN_REG_GEMM_STRATEGY(float, float, float, 4, 16, 1, false, true, sgemm_4x16); | |||
| MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, false, | |||
| sgemm_mk4_8x12); | |||
| MEGDNN_REG_GEMM_STRATEGY(float, float, float, 8, 12, 1, false, false, sgemm_mk4_8x12); | |||
| MEGDNN_REG_GEMM_STRATEGY_NOPACK(float, float, float, 4, 16, 1, false, true, | |||
| sgemm_nopack_4x16); | |||
| MEGDNN_REG_GEMM_STRATEGY_NOPACK( | |||
| float, float, float, 4, 16, 1, false, true, sgemm_nopack_4x16); | |||
| } // namespace matmul | |||
| } // namespace aarch64 | |||
| @@ -20,8 +20,8 @@ using namespace aarch64::matmul; | |||
| namespace { | |||
| void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
| float* output) { | |||
| void kern_4x1( | |||
| const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) { | |||
| LDB *= sizeof(float); | |||
| asm volatile( | |||
| "subs %w[K], %w[K], #4\n" | |||
| @@ -64,8 +64,7 @@ void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [output] "+r"(output), [LDB] "+r"(LDB) | |||
| : | |||
| : "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc", | |||
| "memory"); | |||
| : "v0", "v4", "v5", "v6", "v7", "v16", "v17", "v18", "v19", "cc", "memory"); | |||
| } | |||
| // Overview of register layout: | |||
| @@ -89,8 +88,8 @@ void kern_4x1(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
| // +--------+ - - - - -+--------+ | |||
| // Accumulator | |||
| void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
| float* output) { | |||
| void kern_4x4( | |||
| const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) { | |||
| //! Each step loads 16 numbers from B while the position advances by only 12 | |||
| //! floats (12 * 4 bytes), so subtract 12 from LDB here. | |||
| LDB = (LDB - 12) * sizeof(float); | |||
| @@ -165,8 +164,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [output] "+r"(output), [LDB] "+r"(LDB) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", | |||
| "v18", "v19", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", | |||
| "v19", "cc", "memory"); | |||
| } | |||
| // Overview of register layout: | |||
| @@ -195,8 +194,8 @@ void kern_4x4(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
| // +--------+ - - - - -+--------+ | |||
| // Accumulator | |||
| void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
| float* output) { | |||
| void kern_4x8( | |||
| const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, float* output) { | |||
| //! Each step loads 32 numbers from B while the position advances by only 24 | |||
| //! floats (24 * 4 bytes), so subtract 24 from LDB here. | |||
| LDB = (LDB - 24) * sizeof(float); | |||
| @@ -304,9 +303,9 @@ void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [output] "+r"(output), [LDB] "+r"(LDB) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", | |||
| "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", | |||
| "v27", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v16", "v17", "v18", | |||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "cc", | |||
| "memory"); | |||
| } | |||
| // Overview of register layout: | |||
| @@ -342,8 +341,7 @@ void kern_4x8(const float* a_ptr, const float* b_ptr, size_t LDB, size_t K, | |||
| // +--------+ | |||
| // Accumulator | |||
| void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, | |||
| float* output) { | |||
| void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, float* output) { | |||
| //! Each step loads 64 numbers from B while the position advances by only 56 | |||
| //! floats (56 * 4 bytes), so subtract 56 from LDB here. | |||
| LDB = (LDB - 56) * sizeof(float); | |||
| @@ -565,20 +563,18 @@ void kern_4x16(const float* a_ptr, const float* b_ptr, int LDB, int K, | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [output] "+r"(output), [LDB] "+r"(LDB) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", | |||
| "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", | |||
| "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |||
| "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); | |||
| } | |||
| } // namespace | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(sgemm_nopack_4x16); | |||
| void sgemm_nopack_4x16::kern(const float* A, size_t LDA, const float* B, | |||
| size_t LDB, float* C, size_t LDC, size_t M, | |||
| size_t K, size_t N, const float*, void*, bool trA, | |||
| bool trB) const { | |||
| void sgemm_nopack_4x16::kern( | |||
| const float* A, size_t LDA, const float* B, size_t LDB, float* C, size_t LDC, | |||
| size_t M, size_t K, size_t N, const float*, void*, bool trA, bool trB) const { | |||
| constexpr static size_t MB = 4; | |||
| constexpr static size_t KB = 4; | |||
| constexpr static size_t NB = 16; | |||
| @@ -46,8 +46,9 @@ namespace matmul_12x8x1 { | |||
| * Accumulator | |||
| */ | |||
| static void kern_12x8(const int16_t* packA, const int16_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k) { | |||
| static void kern_12x8( | |||
| const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k) { | |||
| const int16_t* a_ptr = packA; | |||
| const int16_t* b_ptr = packB; | |||
| @@ -155,15 +156,13 @@ static void kern_12x8(const int16_t* packA, const int16_t* packB, int K, | |||
| "stp q25, q26, [x9]\n" | |||
| "stp q27, q28, [x10]\n" | |||
| "stp q29, q30, [x11]\n" | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [output] "+r"(output) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output) | |||
| : | |||
| : "v0", "v1", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", | |||
| "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||
| "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "x1", | |||
| "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", | |||
| "cc", "memory"); | |||
| : "v0", "v1", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", | |||
| "v25", "v26", "v27", "v28", "v29", "v30", "x1", "x2", "x3", "x4", "x5", | |||
| "x6", "x7", "x8", "x9", "x10", "x11", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| #undef STORE_LINE | |||
| @@ -196,8 +195,9 @@ static void kern_12x8(const int16_t* packA, const int16_t* packB, int K, | |||
| * Accumulator | |||
| */ | |||
| static void kern_8x8(const int16_t* packA, const int16_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k) { | |||
| static void kern_8x8( | |||
| const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k) { | |||
| const int16_t* a_ptr = packA; | |||
| const int16_t* b_ptr = packB; | |||
| @@ -276,13 +276,12 @@ static void kern_8x8(const int16_t* packA, const int16_t* packB, int K, | |||
| "stp q17, q18, [x5]\n" | |||
| "stp q19, q20, [x6]\n" | |||
| "stp q21, q22, [x7]\n" | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [output] "+r"(output) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output) | |||
| : | |||
| : "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "x1", | |||
| "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); | |||
| : "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |||
| "v16", "v17", "v18", "v19", "v20", "v21", "v22", "x1", "x2", "x3", "x4", | |||
| "x5", "x6", "x7", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| #undef STORE_LINE | |||
| @@ -311,9 +310,9 @@ static void kern_8x8(const int16_t* packA, const int16_t* packB, int K, | |||
| * Accumulator | |||
| */ | |||
| static void kern_4x8(const int16_t* packA, const int16_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k, | |||
| size_t m_remain) { | |||
| static void kern_4x8( | |||
| const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k, size_t m_remain) { | |||
| const int16_t* a_ptr = packA; | |||
| const int16_t* b_ptr = packB; | |||
| @@ -388,14 +387,13 @@ static void kern_4x8(const int16_t* packA, const int16_t* packB, int K, | |||
| "cbnz %w[K], 2b\n" | |||
| "3:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), | |||
| [m_remain] "+r"(m_remain) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [x0] "+r"(x0), [m_remain] "+r"(m_remain) | |||
| : | |||
| : "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
| "cc", "memory"); | |||
| : "v0", "v2", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "cc", | |||
| "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| #undef STORE_LINE | |||
| @@ -432,9 +430,9 @@ static void kern_4x8(const int16_t* packA, const int16_t* packB, int K, | |||
| * Accumulator | |||
| */ | |||
| static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k, | |||
| size_t n_remain) { | |||
| static void kern_12x4( | |||
| const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k, size_t n_remain) { | |||
| const int16_t* a_ptr = packA; | |||
| const int16_t* b_ptr = packB; | |||
| @@ -573,18 +571,16 @@ static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, | |||
| "cbnz %w[K], 2b\n" | |||
| "3:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||
| [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), | |||
| [outptr8] "=r"(outptr8), [outptr9] "=r"(outptr9), | |||
| [outptr10] "=r"(outptr10), [outptr11] "=r"(outptr11), | |||
| [x0] "+r"(x0), [n_remain] "+r"(n_remain) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||
| [outptr7] "=r"(outptr7), [outptr8] "=r"(outptr8), [outptr9] "=r"(outptr9), | |||
| [outptr10] "=r"(outptr10), [outptr11] "=r"(outptr11), [x0] "+r"(x0), | |||
| [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
| "v15", "v16", "v17", "v18", "v19", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |||
| "v16", "v17", "v18", "v19", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -618,9 +614,9 @@ static void kern_12x4(const int16_t* packA, const int16_t* packB, int K, | |||
| * Accumulator | |||
| */ | |||
| static void kern_8x4(const int16_t* packA, const int16_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k, | |||
| size_t n_remain) { | |||
| static void kern_8x4( | |||
| const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k, size_t n_remain) { | |||
| const int16_t* a_ptr = packA; | |||
| const int16_t* b_ptr = packB; | |||
| @@ -734,16 +730,14 @@ static void kern_8x4(const int16_t* packA, const int16_t* packB, int K, | |||
| "cbnz %w[K], 2b\n" | |||
| "3:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||
| [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), | |||
| [n_remain] "+r"(n_remain) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||
| [outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", | |||
| "cc", "memory"); | |||
| : "v0", "v2", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "cc", | |||
| "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -773,9 +767,9 @@ static void kern_8x4(const int16_t* packA, const int16_t* packB, int K, | |||
| * Accumulator | |||
| */ | |||
| static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k, size_t m_remain, | |||
| size_t n_remain) { | |||
| static void kern_4x4( | |||
| const int16_t* packA, const int16_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k, size_t m_remain, size_t n_remain) { | |||
| const int16_t* a_ptr = packA; | |||
| const int16_t* b_ptr = packB; | |||
| @@ -874,11 +868,10 @@ static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, | |||
| "cbnz %w[K], 2b\n" | |||
| "3:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), | |||
| [m_remain] "+r"(m_remain), [x1] "+r"(x1), | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [x0] "+r"(x0), [m_remain] "+r"(m_remain), [x1] "+r"(x1), | |||
| [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v2", "v8", "v9", "v10", "v11", "cc", "memory"); | |||
| @@ -889,9 +882,9 @@ static void kern_4x4(const int16_t* packA, const int16_t* packB, int K, | |||
| #undef STORE_C | |||
| } | |||
| static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||
| int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| static void gemm_s16_12x8x1_pack_A_n( | |||
| int16_t* outptr, const int16_t* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| int16_t zerobuff[4]; | |||
| std::memset(zerobuff, 0, sizeof(int16_t) * 4); | |||
| @@ -925,15 +918,15 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||
| int K = kmax - k0; | |||
| for (; K > 3; K -= 4) { | |||
| interleave_12x1_4_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, inptr8, inptr9, inptr10, | |||
| inptr11, outptr); | |||
| interleave_12x1_4_h( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| inptr8, inptr9, inptr10, inptr11, outptr); | |||
| } | |||
| if (K > 0) { | |||
| interleave_12(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, | |||
| outptr, 1, K); | |||
| interleave_12( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| inptr8, inptr9, inptr10, inptr11, outptr, 1, K); | |||
| } | |||
| } | |||
| @@ -949,13 +942,15 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||
| int K = kmax - k0; | |||
| for (; K > 7; K -= 8) { | |||
| interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, outptr); | |||
| interleave_8x1_8_h( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr); | |||
| } | |||
| if (K > 0) { | |||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
| inptr7, outptr, 1, K); | |||
| interleave_8( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr, 1, K); | |||
| } | |||
| } | |||
| @@ -975,9 +970,11 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -992,9 +989,11 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -1007,9 +1006,8 @@ static void gemm_s16_12x8x1_pack_A_n(int16_t* outptr, const int16_t* inptr, | |||
| } | |||
| } | |||
| static void gemm_s16_12x8x1_transpose_pack_A_n(int16_t* out, const int16_t* in, | |||
| int ldin, int x0, int xmax, | |||
| int k0, int kmax) { | |||
| static void gemm_s16_12x8x1_transpose_pack_A_n( | |||
| int16_t* out, const int16_t* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| const int ksize = kmax - k0; | |||
| const int ksize4 = ksize * 4; | |||
| const int ksize8 = ksize4 * 2; | |||
| @@ -1054,8 +1052,8 @@ static void gemm_s16_12x8x1_transpose_pack_A_n(int16_t* out, const int16_t* in, | |||
| } | |||
| } | |||
| static void gemm_s16_12x8x1_pack_B_n(int16_t* out, const int16_t* in, int ldin, | |||
| int x0, int xmax, int k0, int kmax) { | |||
| static void gemm_s16_12x8x1_pack_B_n( | |||
| int16_t* out, const int16_t* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| const int ksize = kmax - k0; | |||
| const int ksize4 = ksize * 4; | |||
| const int ksize8 = ksize4 * 2; | |||
| @@ -1090,10 +1088,9 @@ static void gemm_s16_12x8x1_pack_B_n(int16_t* out, const int16_t* in, int ldin, | |||
| } | |||
| } | |||
| static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, | |||
| const int16_t* inptr, int ldin, | |||
| int y0, int ymax, int k0, | |||
| int kmax) { | |||
| static void gemm_s16_12x8x1_transpose_pack_B_n( | |||
| int16_t* outptr, const int16_t* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| int16_t zerobuff[4]; | |||
| std::memset(zerobuff, 0, sizeof(int16_t) * 4); | |||
| @@ -1110,13 +1107,15 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, | |||
| int K = kmax - k0; | |||
| for (; K > 7; K -= 8) { | |||
| interleave_8x1_8_h(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, outptr); | |||
| interleave_8x1_8_h( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr); | |||
| } | |||
| if (K > 0) { | |||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
| inptr7, outptr, 1, K); | |||
| interleave_8( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr, 1, K); | |||
| } | |||
| } | |||
| @@ -1136,9 +1135,11 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -1153,9 +1154,11 @@ static void gemm_s16_12x8x1_transpose_pack_B_n(int16_t* outptr, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -22,39 +22,37 @@ using namespace aarch64::matmul; | |||
| ///////////////////////// gemm_s16_12x8x1 //////////////////////////////////// | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s16_12x8x1); | |||
| void gemm_s16_12x8x1::pack_A(dt_int16* outptr, const dt_int16* inptr, int ldin, | |||
| int y0, int ymax, int k0, int kmax, | |||
| bool transpose) const { | |||
| void gemm_s16_12x8x1::pack_A( | |||
| dt_int16* outptr, const dt_int16* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool transpose) const { | |||
| if (transpose) { | |||
| matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n(outptr, inptr, ldin, | |||
| y0, ymax, k0, kmax); | |||
| matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_A_n( | |||
| outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
| } else { | |||
| matmul_12x8x1::gemm_s16_12x8x1_pack_A_n(outptr, inptr, ldin, y0, ymax, | |||
| k0, kmax); | |||
| matmul_12x8x1::gemm_s16_12x8x1_pack_A_n( | |||
| outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| } | |||
| void gemm_s16_12x8x1::pack_B(dt_int16* out, const dt_int16* in, int ldin, | |||
| int x0, int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| void gemm_s16_12x8x1::pack_B( | |||
| dt_int16* out, const dt_int16* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| if (transpose) { | |||
| matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n(out, in, ldin, x0, | |||
| xmax, k0, kmax); | |||
| matmul_12x8x1::gemm_s16_12x8x1_transpose_pack_B_n( | |||
| out, in, ldin, x0, xmax, k0, kmax); | |||
| } else { | |||
| matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0, | |||
| kmax); | |||
| matmul_12x8x1::gemm_s16_12x8x1_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
| } | |||
| } | |||
| void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||
| size_t M, size_t N, size_t K, dt_int32* C, | |||
| size_t LDC, bool is_first_k, const dt_int32*, | |||
| dt_int32*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| (A_dtype.enumv() == DTypeEnum::Int16 && | |||
| C_dtype.enumv() == DTypeEnum::Int32), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
| C_dtype.name()); | |||
| void gemm_s16_12x8x1::kern( | |||
| const dt_int16* packA, const dt_int16* packB, size_t M, size_t N, size_t K, | |||
| dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||
| megdnn_assert( | |||
| A_dtype.enumv() == B_dtype.enumv() && | |||
| (A_dtype.enumv() == DTypeEnum::Int16 && | |||
| C_dtype.enumv() == DTypeEnum::Int32), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| @@ -72,15 +70,15 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||
| size_t n = 0; | |||
| const dt_int16* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC, | |||
| is_first_k); | |||
| matmul_12x8x1::kern_12x8(packA, cur_packB, K, output, LDC, is_first_k); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K8; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_12x8x1::kern_12x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, std::min<size_t>(N - n, 4)); | |||
| matmul_12x8x1::kern_12x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| output += 4; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -92,15 +90,15 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||
| const dt_int16* cur_packB = packB; | |||
| size_t n = 0; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC, | |||
| is_first_k); | |||
| matmul_12x8x1::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K8; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_12x8x1::kern_8x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, std::min<size_t>(N - n, 4)); | |||
| matmul_12x8x1::kern_8x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| output += 4; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -112,16 +110,17 @@ void gemm_s16_12x8x1::kern(const dt_int16* packA, const dt_int16* packB, | |||
| const dt_int16* cur_packB = packB; | |||
| size_t n = 0; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_12x8x1::kern_4x8(packA, cur_packB, K, output, LDC, | |||
| is_first_k, std::min<size_t>(M - m, 4)); | |||
| matmul_12x8x1::kern_4x8( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4)); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K8; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_12x8x1::kern_4x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, std::min<size_t>(M - m, 4), | |||
| std::min<size_t>(N - n, 4)); | |||
| matmul_12x8x1::kern_4x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
| output += 4; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -16,11 +16,11 @@ namespace megdnn { | |||
| namespace aarch64 { | |||
| namespace matmul { | |||
| MEGDNN_REG_GEMM_STRATEGY(dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true, | |||
| gemm_s16_12x8x1); | |||
| MEGDNN_REG_GEMM_STRATEGY( | |||
| dt_int16, dt_int32, dt_int32, 12, 8, 1, false, true, gemm_s16_12x8x1); | |||
| MEGDNN_REG_GEMM_STRATEGY_NOPACK(dt_int16, dt_int32, dt_int32, 8, 8, 1, false, | |||
| true, gemm_nopack_s16_8x8); | |||
| MEGDNN_REG_GEMM_STRATEGY_NOPACK( | |||
| dt_int16, dt_int32, dt_int32, 8, 8, 1, false, true, gemm_nopack_s16_8x8); | |||
| } // namespace matmul | |||
| } // namespace aarch64 | |||
| @@ -9,8 +9,8 @@ | |||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include "src/aarch64/matrix_mul/int16/strategy.h" | |||
| #include "src/aarch64/matrix_mul/asm/common.h" | |||
| #include "src/aarch64/matrix_mul/int16/strategy.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/utils.h" | |||
| @@ -20,8 +20,9 @@ using namespace aarch64::matmul; | |||
| namespace { | |||
| void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
| dt_int32* output) { | |||
| void kern_8x1( | |||
| const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
| dt_int32* output) { | |||
| //! As each load 32 number from B, but the pos add 24 * 2, so we minus 24 | |||
| //! here. | |||
| LDB *= sizeof(dt_int16); | |||
| @@ -91,9 +92,8 @@ void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [output] "+r"(output), [LDB] "+r"(LDB) | |||
| : | |||
| : "v0", "v16", "v17", "v18", "v19", "v20", "v21", | |||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
| "v29", "v30", "v31", "cc", "memory"); | |||
| : "v0", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", | |||
| "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); | |||
| } | |||
| // Overview of register layout: | |||
| @@ -120,8 +120,9 @@ void kern_8x1(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
| // | v31[0-7]| |v23[0-3]| | |||
| // +---------+ +--------+ | |||
| // Accumulator | |||
| void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
| dt_int32* output) { | |||
| void kern_8x4( | |||
| const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
| dt_int32* output) { | |||
| //! Since each load takes 32 numbers from B while the position advances by | |||
| //! 24 * 2, we subtract 24 here. | |||
| LDB = (LDB - 24) * sizeof(dt_int16); | |||
| @@ -349,9 +350,9 @@ void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [output] "+r"(output), [LDB] "+r"(LDB) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", | |||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
| "v29", "v30", "v31", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||
| "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", | |||
| "memory"); | |||
| } | |||
| // Overview of register layout: | |||
| @@ -382,8 +383,9 @@ void kern_8x4(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
| // | v7[0-7]| |v30[0-3]|v31[0-3]| | |||
| // +--------+ +--------+--------+ | |||
| // Accumulator | |||
| void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
| dt_int32* output) { | |||
| void kern_8x8( | |||
| const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
| dt_int32* output) { | |||
| //! Since each load takes 64 numbers from B while the position advances by | |||
| //! 48 * 2, we subtract 48 here. | |||
| LDB = (LDB - 48) * sizeof(dt_int16); | |||
| @@ -693,20 +695,20 @@ void kern_8x8(const dt_int16* a_ptr, const dt_int16* b_ptr, int LDB, int K, | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [output] "+r"(output), [LDB] "+r"(LDB) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
| "v29", "v30", "v31", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||
| "cc", "memory"); | |||
| } | |||
| } // anonymous namespace | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL_NOPACK(gemm_nopack_s16_8x8); | |||
| void gemm_nopack_s16_8x8::kern(const dt_int16* A, size_t LDA, const dt_int16* B, | |||
| size_t LDB, dt_int32* C, size_t LDC, size_t M, | |||
| size_t K, size_t N, const dt_int32*, void*, | |||
| bool trA, bool trB) const { | |||
| void gemm_nopack_s16_8x8::kern( | |||
| const dt_int16* A, size_t LDA, const dt_int16* B, size_t LDB, dt_int32* C, | |||
| size_t LDC, size_t M, size_t K, size_t N, const dt_int32*, void*, bool trA, | |||
| bool trB) const { | |||
| constexpr static size_t MB = 8; | |||
| constexpr static size_t KB = 8; | |||
| constexpr static size_t NB = 8; | |||
| @@ -36,9 +36,9 @@ namespace matmul_s4_4x4x16 { | |||
| * Accumulator | |||
| */ | |||
| static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
| int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
| int n_remain) { | |||
| static void s4_kern_8x8_remain( | |||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
| bool is_first_k, int m_remain, int n_remain) { | |||
| K /= 8; | |||
| LDC = LDC * sizeof(int16_t); | |||
| const int8_t* a_ptr = packA; | |||
| @@ -170,7 +170,7 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
| "dup v5.8b,v20.b[5]\n" | |||
| "dup v6.8b,v20.b[6]\n" | |||
| "dup v7.8b,v20.b[7]\n" | |||
| "ld1 {v17.8b}, [%[b_ptr]], 8\n" | |||
| "dup v8.8b,v20.b[8]\n" | |||
| @@ -318,16 +318,16 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
| STORE_C | |||
| : | |||
| [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), | |||
| [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), | |||
| [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), | |||
| [ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||
| [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||
| [m_remain] "+r"(m_remain), | |||
| [n_remain] "+r"( | |||
| n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||
| : | |||
| : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", | |||
| "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
| "v29", "v30", "v31"); | |||
| : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0", | |||
| "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||
| "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||
| "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -335,14 +335,14 @@ static void s4_kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
| #undef STORE_C | |||
| } | |||
| static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
| int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
| int n_remain) { | |||
| static void s4_kern_8x8( | |||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
| bool is_first_k, int m_remain, int n_remain) { | |||
| K /= 8; | |||
| LDC = LDC * sizeof(int16_t); | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| // clang-format off | |||
| // clang-format off | |||
| #define LOAD_C_8 \ | |||
| "ld1 {v24.8h}, [x0], #16\n" \ | |||
| @@ -363,9 +363,9 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
| "st1 {v28.8h}, [x4], #16\n" \ | |||
| "st1 {v29.8h}, [x5], #16\n" \ | |||
| "st1 {v30.8h}, [x6], #16\n" \ | |||
| "st1 {v31.8h}, [x7], #16\n" \ | |||
| "st1 {v31.8h}, [x7], #16\n" | |||
| // clang-format on | |||
| // clang-format on | |||
| register int16_t* outptr asm("x0") = output; | |||
| asm volatile( | |||
| "add x1, x0, %x[LDC]\n" | |||
| @@ -395,8 +395,8 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
| "PRFM PLDL1KEEP, [%[a_ptr], #512]\n" | |||
| "PRFM PLDL1KEEP, [%[b_ptr], #512]\n" | |||
| "1:\n" | |||
| // "ld1 {v20.16b}, [%[a_ptr]],#16\n" | |||
| // "ld1 {v21.16b}, [%[a_ptr]],#16\n" | |||
| // "ld1 {v20.16b}, [%[a_ptr]],#16\n" | |||
| // "ld1 {v21.16b}, [%[a_ptr]],#16\n" | |||
| "dup v0.8b,v20.b[0]\n" | |||
| "ld1 {v22.16b}, [%[a_ptr]],#16\n" | |||
| "dup v1.8b,v20.b[1]\n" | |||
| @@ -409,7 +409,6 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
| "dup v5.8b,v20.b[5]\n" | |||
| "dup v6.8b,v20.b[6]\n" | |||
| "dup v7.8b,v20.b[7]\n" | |||
| "dup v8.8b,v20.b[8]\n" | |||
| "smlal v24.8h, v0.8b, v16.8b\n" | |||
| @@ -560,26 +559,26 @@ static void s4_kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
| STORE_C_8 | |||
| : | |||
| [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), | |||
| [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), | |||
| [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), | |||
| [ n_remain ] "+r"(n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||
| [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||
| [m_remain] "+r"(m_remain), | |||
| [n_remain] "+r"( | |||
| n_remain) //,[tmp_packa1]"+r"(tmp_packa1),[tmp_packb1]"+r"(tmp_packb1) | |||
| : | |||
| : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", | |||
| "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
| "v29", "v30", "v31"); | |||
| : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0", | |||
| "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||
| "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||
| "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| #undef STORE_LINE | |||
| #undef STORE_C | |||
| } | |||
| //packa | |||
| static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* inptr, | |||
| int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| // packa | |||
| static void gemm_s4x4x16_8x8x8_transpose_pack( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| int8_t zerobuff[8]; | |||
| int8_t tmpbuff0[8]; | |||
| int8_t tmpbuff1[8]; | |||
| @@ -617,22 +616,23 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||
| prefetch_2x(inptr5); | |||
| prefetch_2x(inptr6); | |||
| prefetch_2x(inptr7); | |||
| int K = (kmax - k0)/2; | |||
| int K = (kmax - k0) / 2; | |||
| //! read 4 * 16 in each row | |||
| for (; K > 3; K -= 4) { | |||
| transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||
| inptr5, inptr6, inptr7, outptr); | |||
| transpose_4x8_1_b_with_shift( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr); | |||
| } | |||
| if (K > 0) { | |||
| std::memcpy(tmpbuff0,inptr0,K); | |||
| std::memcpy(tmpbuff1,inptr1,K); | |||
| std::memcpy(tmpbuff2,inptr2,K); | |||
| std::memcpy(tmpbuff3,inptr3,K); | |||
| std::memcpy(tmpbuff4,inptr4,K); | |||
| std::memcpy(tmpbuff5,inptr5,K); | |||
| std::memcpy(tmpbuff6,inptr6,K); | |||
| std::memcpy(tmpbuff7,inptr7,K); | |||
| std::memcpy(tmpbuff0, inptr0, K); | |||
| std::memcpy(tmpbuff1, inptr1, K); | |||
| std::memcpy(tmpbuff2, inptr2, K); | |||
| std::memcpy(tmpbuff3, inptr3, K); | |||
| std::memcpy(tmpbuff4, inptr4, K); | |||
| std::memcpy(tmpbuff5, inptr5, K); | |||
| std::memcpy(tmpbuff6, inptr6, K); | |||
| std::memcpy(tmpbuff7, inptr7, K); | |||
| inptr0 = tmpbuff0; | |||
| inptr1 = tmpbuff1; | |||
| inptr2 = tmpbuff2; | |||
| @@ -641,8 +641,9 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||
| inptr5 = tmpbuff5; | |||
| inptr6 = tmpbuff6; | |||
| inptr7 = tmpbuff7; | |||
| transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||
| inptr5, inptr6, inptr7, outptr); | |||
| transpose_4x8_1_b_with_shift( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr); | |||
| } | |||
| } | |||
| for (; y < ymax; y += 8) { | |||
| @@ -655,23 +656,29 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||
| const int8_t* inptr6 = inptr5 + ldin; | |||
| const int8_t* inptr7 = inptr6 + ldin; | |||
| int K = (kmax - k0)/2; | |||
| int K = (kmax - k0) / 2; | |||
| //! read 4 * 16 in each row | |||
| for (; K > 3; K -= 4) { | |||
| if (y + 7 >= ymax) { | |||
| switch (y + 7 - ymax) { | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -679,24 +686,31 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||
| megdnn_assert(0); | |||
| } | |||
| } | |||
| transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||
| inptr5, inptr6, inptr7, outptr); | |||
| transpose_4x8_1_b_with_shift( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr); | |||
| } | |||
| if (K > 0) { | |||
| if (y + 7 >= ymax) { | |||
| switch (y + 7 - ymax) { | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -705,14 +719,14 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||
| } | |||
| } | |||
| std::memcpy(tmpbuff0,inptr0,K); | |||
| std::memcpy(tmpbuff1,inptr1,K); | |||
| std::memcpy(tmpbuff2,inptr2,K); | |||
| std::memcpy(tmpbuff3,inptr3,K); | |||
| std::memcpy(tmpbuff4,inptr4,K); | |||
| std::memcpy(tmpbuff5,inptr5,K); | |||
| std::memcpy(tmpbuff6,inptr6,K); | |||
| std::memcpy(tmpbuff7,inptr7,K); | |||
| std::memcpy(tmpbuff0, inptr0, K); | |||
| std::memcpy(tmpbuff1, inptr1, K); | |||
| std::memcpy(tmpbuff2, inptr2, K); | |||
| std::memcpy(tmpbuff3, inptr3, K); | |||
| std::memcpy(tmpbuff4, inptr4, K); | |||
| std::memcpy(tmpbuff5, inptr5, K); | |||
| std::memcpy(tmpbuff6, inptr6, K); | |||
| std::memcpy(tmpbuff7, inptr7, K); | |||
| inptr0 = tmpbuff0; | |||
| inptr1 = tmpbuff1; | |||
| inptr2 = tmpbuff2; | |||
| @@ -721,14 +735,15 @@ static void gemm_s4x4x16_8x8x8_transpose_pack(dt_int8* outptr, const dt_int8* in | |||
| inptr5 = tmpbuff5; | |||
| inptr6 = tmpbuff6; | |||
| inptr7 = tmpbuff7; | |||
| transpose_4x8_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, | |||
| inptr5, inptr6, inptr7, outptr); | |||
| transpose_4x8_1_b_with_shift( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr); | |||
| } | |||
| } | |||
| } | |||
| //packb | |||
| static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, int ldin, | |||
| int x0, int xmax, int k0, int kmax) { | |||
| // packb | |||
| static void gemm_s4x4x16_8x8x8_interleave_pack( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| int8_t zerobuff[8]; | |||
| int8_t tmpbuff0[8]; | |||
| int8_t tmpbuff1[8]; | |||
| @@ -748,7 +763,7 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||
| std::memset(tmpbuff6, 0, sizeof(int8_t) * 8); | |||
| std::memset(tmpbuff7, 0, sizeof(int8_t) * 8); | |||
| const int ksize = kmax - k0; | |||
| const int ksize8 = round_up(ksize, 8) * 8; //pack to int8 *8 packto s4 *4 | |||
| const int ksize8 = round_up(ksize, 8) * 8; // pack to int8 *8 packto s4 *4 | |||
| int8_t* outptr = out; | |||
| int8_t* outptr_interleave = nullptr; | |||
| @@ -776,21 +791,22 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||
| int8_t* outptr_inner = outptr; | |||
| for (; x + 3 < xmax; x += 4) { | |||
| outptr_interleave = outptr_inner; | |||
| interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, outptr_interleave); | |||
| interleave_8x4_1_b_with_shift( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr_interleave); | |||
| outptr_inner += ksize8; | |||
| } | |||
| if (x < xmax) { | |||
| int remainx = xmax - x; | |||
| std::memcpy(tmpbuff0,inptr0,remainx); | |||
| std::memcpy(tmpbuff1,inptr1,remainx); | |||
| std::memcpy(tmpbuff2,inptr2,remainx); | |||
| std::memcpy(tmpbuff3,inptr3,remainx); | |||
| std::memcpy(tmpbuff4,inptr4,remainx); | |||
| std::memcpy(tmpbuff5,inptr5,remainx); | |||
| std::memcpy(tmpbuff6,inptr6,remainx); | |||
| std::memcpy(tmpbuff7,inptr7,remainx); | |||
| std::memcpy(tmpbuff0, inptr0, remainx); | |||
| std::memcpy(tmpbuff1, inptr1, remainx); | |||
| std::memcpy(tmpbuff2, inptr2, remainx); | |||
| std::memcpy(tmpbuff3, inptr3, remainx); | |||
| std::memcpy(tmpbuff4, inptr4, remainx); | |||
| std::memcpy(tmpbuff5, inptr5, remainx); | |||
| std::memcpy(tmpbuff6, inptr6, remainx); | |||
| std::memcpy(tmpbuff7, inptr7, remainx); | |||
| inptr0 = tmpbuff0; | |||
| inptr1 = tmpbuff1; | |||
| inptr2 = tmpbuff2; | |||
| @@ -801,8 +817,9 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||
| inptr7 = tmpbuff7; | |||
| outptr_interleave = outptr_inner; | |||
| interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, outptr_interleave); | |||
| interleave_8x4_1_b_with_shift( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr_interleave); | |||
| outptr_inner += ksize8; | |||
| } | |||
| outptr += 64; | |||
| @@ -847,8 +864,9 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||
| break; | |||
| } | |||
| outptr_interleave = outptr_inner; | |||
| interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, outptr_interleave); | |||
| interleave_8x4_1_b_with_shift( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr_interleave); | |||
| outptr_inner += ksize8; | |||
| } | |||
| if (x < xmax) { | |||
| @@ -880,14 +898,14 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||
| } | |||
| int remainx = xmax - x; | |||
| outptr_interleave = outptr_inner; | |||
| std::memcpy(tmpbuff0,inptr0,remainx); | |||
| std::memcpy(tmpbuff1,inptr1,remainx); | |||
| std::memcpy(tmpbuff2,inptr2,remainx); | |||
| std::memcpy(tmpbuff3,inptr3,remainx); | |||
| std::memcpy(tmpbuff4,inptr4,remainx); | |||
| std::memcpy(tmpbuff5,inptr5,remainx); | |||
| std::memcpy(tmpbuff6,inptr6,remainx); | |||
| std::memcpy(tmpbuff7,inptr7,remainx); | |||
| std::memcpy(tmpbuff0, inptr0, remainx); | |||
| std::memcpy(tmpbuff1, inptr1, remainx); | |||
| std::memcpy(tmpbuff2, inptr2, remainx); | |||
| std::memcpy(tmpbuff3, inptr3, remainx); | |||
| std::memcpy(tmpbuff4, inptr4, remainx); | |||
| std::memcpy(tmpbuff5, inptr5, remainx); | |||
| std::memcpy(tmpbuff6, inptr6, remainx); | |||
| std::memcpy(tmpbuff7, inptr7, remainx); | |||
| inptr0 = tmpbuff0; | |||
| inptr1 = tmpbuff1; | |||
| inptr2 = tmpbuff2; | |||
| @@ -898,16 +916,16 @@ static void gemm_s4x4x16_8x8x8_interleave_pack(dt_int8* out, const dt_int8* in, | |||
| inptr7 = tmpbuff7; | |||
| outptr_interleave = outptr_inner; | |||
| interleave_8x4_1_b_with_shift(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, outptr_interleave); | |||
| interleave_8x4_1_b_with_shift( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr_interleave); | |||
| outptr_inner += ksize8; | |||
| } | |||
| } | |||
| } | |||
| } // namespace matmul_4x4x16 | |||
| } // namespace matmul_s4_4x4x16 | |||
| } // namespace aarch64 | |||
| } // namespace megdnn | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -10,9 +10,9 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/aarch64/matrix_mul/int4x4x16/strategy.h" | |||
| #include "src/aarch64/matrix_mul/asm/common.h" | |||
| #include "src/aarch64/matrix_mul/int4x4x16/kernel_int4_8x8x8.h" | |||
| #include "src/aarch64/matrix_mul/int4x4x16/strategy.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/matrix_mul/gemm_common.h" | |||
| @@ -23,39 +23,38 @@ using namespace aarch64::matmul; | |||
| // ===========================gemm_s4x4x16_s4_8x8x8================================== | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s4x4x16_s4_8x8x8); | |||
| void gemm_s4x4x16_s4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | |||
| int ymax, int k0, int kmax, | |||
| bool transpose) const { | |||
| void gemm_s4x4x16_s4_8x8x8::pack_A( | |||
| dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
| bool transpose) const { | |||
| if (transpose) { | |||
| matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack( | |||
| out, in, ldin, y0, ymax, k0, kmax); | |||
| } else { | |||
| matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack( | |||
| out, in, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| } | |||
| void gemm_s4x4x16_s4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
| int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| void gemm_s4x4x16_s4_8x8x8::pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| if (transpose) { | |||
| matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack(out, in, ldin, x0, xmax, k0, | |||
| kmax); | |||
| matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_transpose_pack( | |||
| out, in, ldin, x0, xmax, k0, kmax); | |||
| } else { | |||
| matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack(out, in, ldin, x0, xmax, k0, | |||
| kmax); | |||
| matmul_s4_4x4x16::gemm_s4x4x16_8x8x8_interleave_pack( | |||
| out, in, ldin, x0, xmax, k0, kmax); | |||
| } | |||
| } | |||
| void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
| size_t M, size_t N, size_t K, dt_int16* C, | |||
| size_t LDC, bool is_first_k, const dt_int16*, | |||
| dt_int16*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| (A_dtype.enumv() == DTypeEnum::QuantizedS4 && | |||
| C_dtype.enumv() == DTypeEnum::QuantizedS16), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
| C_dtype.name()); | |||
| void gemm_s4x4x16_s4_8x8x8::kern( | |||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
| dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||
| megdnn_assert( | |||
| A_dtype.enumv() == B_dtype.enumv() && | |||
| (A_dtype.enumv() == DTypeEnum::QuantizedS4 && | |||
| C_dtype.enumv() == DTypeEnum::QuantizedS16), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| @@ -72,16 +71,17 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
| size_t n = 0; | |||
| const dt_int8* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_s4_4x4x16::s4_kern_8x8(packA, cur_packB, K, output, LDC, | |||
| is_first_k, A_INTERLEAVE, B_INTERLEAVE); | |||
| matmul_s4_4x4x16::s4_kern_8x8( | |||
| packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, | |||
| B_INTERLEAVE); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K8; | |||
| } | |||
| for (; n < N; n += B_INTERLEAVE) { | |||
| matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC, | |||
| is_first_k, A_INTERLEAVE, | |||
| std::min<size_t>(N - n, B_INTERLEAVE)); | |||
| matmul_s4_4x4x16::s4_kern_8x8_remain( | |||
| packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, | |||
| std::min<size_t>(N - n, B_INTERLEAVE)); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K8; | |||
| } | |||
| @@ -94,10 +94,10 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
| size_t n = 0; | |||
| const dt_int8* cur_packB = packB; | |||
| for (; n < N; n += B_INTERLEAVE) { | |||
| matmul_s4_4x4x16::s4_kern_8x8_remain(packA, cur_packB, K, output, LDC, | |||
| is_first_k, | |||
| std::min<size_t>(M - m, A_INTERLEAVE), | |||
| std::min<size_t>(N - n, B_INTERLEAVE)); | |||
| matmul_s4_4x4x16::s4_kern_8x8_remain( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, A_INTERLEAVE), | |||
| std::min<size_t>(N - n, B_INTERLEAVE)); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K8; | |||
| } | |||
| @@ -105,5 +105,4 @@ void gemm_s4x4x16_s4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -17,8 +17,8 @@ namespace megdnn { | |||
| namespace aarch64 { | |||
| namespace matmul { | |||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, | |||
| gemm_s4x4x16_s4_8x8x8); | |||
| MEGDNN_REG_GEMM_STRATEGY( | |||
| dt_int8, dt_int16, dt_int16, 8, 8, 8, false, true, gemm_s4x4x16_s4_8x8x8); | |||
| } // namespace matmul | |||
| } // namespace aarch64 | |||
| @@ -51,8 +51,9 @@ namespace matmul_4x4x16 { | |||
| * Accumulator | |||
| */ | |||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k) { | |||
| static void kern_4x4( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k) { | |||
| K /= 16; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -472,9 +473,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| ); | |||
| } | |||
| static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k, | |||
| int m_remain, int n_remain) { | |||
| static void kern_4x4_remain( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k, int m_remain, int n_remain) { | |||
| megdnn_assert(K > 0); | |||
| K /= 16; | |||
| const int8_t* a_ptr = packA; | |||
| @@ -655,16 +656,14 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||
| STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [output] "+r"(output), [m_remain] "+r"(m_remain), | |||
| [n_remain] "+r"(n_remain) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output), | |||
| [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
| "v29", "v30", "v31", "x0", "x1", "x2", "x3", "x4", "x5", "cc", | |||
| "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||
| "x0", "x1", "x2", "x3", "x4", "x5", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -672,8 +671,9 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||
| #undef STORE_C | |||
| } | |||
| static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
| int ldin, int y0, int ymax, int k0, int kmax) { | |||
| static void gemm_s8_4x4_pack_A_n( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| int8_t zerobuff[16]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| @@ -716,9 +716,11 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -734,9 +736,11 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -749,8 +753,8 @@ static void gemm_s8_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
| } | |||
| } | |||
| static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| int x0, int xmax, int k0, int kmax) { | |||
| static void gemm_s8_4x4_pack_B_n( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| int8_t zerobuff[16]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| const int ksize = kmax - k0; | |||
| @@ -777,19 +781,26 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| if (remain >= 0) { | |||
| switch (remain) { | |||
| case 7: | |||
| inptr0 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr0 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -798,9 +809,9 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| } | |||
| } | |||
| transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, | |||
| inptr4, inptr5, inptr6, inptr7, | |||
| outptr_inner); | |||
| transpose_4x16_1_b_helper( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr_inner); | |||
| outptr_inner += ksize4; | |||
| } | |||
| @@ -808,19 +819,26 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| if (remain >= 0) { | |||
| switch (remain) { | |||
| case 7: | |||
| inptr0 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr0 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -42,8 +42,9 @@ namespace matmul_8x8x8 { | |||
| * Accumulator | |||
| */ | |||
| static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k) { | |||
| static void kern_8x8( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k) { | |||
| K /= 8; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -272,14 +273,13 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
| "stp q18, q19, [x5]\n" | |||
| "stp q20, q21, [x6]\n" | |||
| "stp q22, q23, [x7]\n" | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [output] "+r"(output) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [output] "+r"(output) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| "v20", "v21", "v22", "v23", "v26", "v27", "x1", | |||
| "x2", "x3", "x4", "x5", "x6", "x7", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
| "v22", "v23", "v26", "v27", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||
| "cc", "memory"); | |||
| } | |||
| /** | |||
| @@ -309,9 +309,9 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
| * Accumulator | |||
| */ | |||
| static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k, | |||
| size_t n_remain) { | |||
| static void kern_8x4( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k, size_t n_remain) { | |||
| K /= 8; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -520,16 +520,14 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
| "cbnz %w[K], 2b\n" | |||
| "3:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||
| [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), | |||
| [n_remain] "+r"(n_remain) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||
| [outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
| "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -559,9 +557,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
| * Accumulator | |||
| */ | |||
| static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k, | |||
| size_t m_remain) { | |||
| static void kern_4x8( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k, size_t m_remain) { | |||
| K /= 8; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -724,14 +722,13 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
| "cbnz %w[K], 2b\n" | |||
| "3:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [outptr0] "+r"(outptr0), | |||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), | |||
| [outptr3] "=r"(outptr3), [x0] "+r"(x0), [m_remain] "+r"(m_remain) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [x0] "+r"(x0), [m_remain] "+r"(m_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
| "v12", "v13", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -762,9 +759,9 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
| * Accumulator | |||
| */ | |||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k, size_t m_remain, | |||
| size_t n_remain) { | |||
| static void kern_4x4( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k, size_t m_remain, size_t n_remain) { | |||
| K /= 8; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -922,11 +919,10 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| "cbnz %w[K], 2b\n" | |||
| "3:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), | |||
| [x1] "+r"(x1), [m_remain] "+r"(m_remain), | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [x0] "+r"(x0), [x1] "+r"(x1), [m_remain] "+r"(m_remain), | |||
| [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v11", "cc", | |||
| @@ -938,8 +934,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| #undef STORE_C | |||
| } | |||
| static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||
| int y0, int ymax, int k0, int kmax) { | |||
| static void gemm_s8_8x8_pack_A_n( | |||
| int8_t* outptr, const int8_t* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| int8_t zerobuff[16]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| @@ -965,13 +962,15 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||
| int K = kmax - k0; | |||
| for (; K > 15; K -= 16) { | |||
| interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, outptr); | |||
| interleave_8x8_2_b( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr); | |||
| } | |||
| if (K > 0) { | |||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
| inptr7, outptr, 8, K); | |||
| interleave_8( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr, 8, K); | |||
| } | |||
| } | |||
| @@ -991,9 +990,11 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -1009,9 +1010,11 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -1024,9 +1027,8 @@ static void gemm_s8_8x8_pack_A_n(int8_t* outptr, const int8_t* inptr, int ldin, | |||
| } | |||
| } | |||
| static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||
| int ldin, int x0, int xmax, int k0, | |||
| int kmax) { | |||
| static void gemm_s8_8x8_transpose_pack_A_n( | |||
| int8_t* out, const int8_t* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| int8_t zerobuff[16]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| const int ksize = kmax - k0; | |||
| @@ -1063,17 +1065,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||
| if (k + 7 >= kmax) { | |||
| switch (k + 7 - kmax) { | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -1081,8 +1089,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||
| megdnn_assert(0); | |||
| } | |||
| } | |||
| transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, outptr); | |||
| transpose_8x8_1_b( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr); | |||
| outptr += ksize8; | |||
| } | |||
| @@ -1091,17 +1100,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||
| if (k + 7 >= kmax) { | |||
| switch (k + 7 - kmax) { | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -1110,8 +1125,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||
| } | |||
| } | |||
| transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
| inptr7, outptr, 4, 4); | |||
| transpose_8( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr, 4, 4); | |||
| outptr += ksize4; | |||
| } | |||
| @@ -1119,17 +1135,23 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||
| if (k + 7 >= kmax) { | |||
| switch (k + 7 - kmax) { | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -1138,8 +1160,9 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||
| } | |||
| } | |||
| transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
| inptr7, outptr, 4, xmax - x); | |||
| transpose_8( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr, 4, xmax - x); | |||
| } | |||
| outptr_base += 8 * 8; | |||
| @@ -1147,8 +1170,8 @@ static void gemm_s8_8x8_transpose_pack_A_n(int8_t* out, const int8_t* in, | |||
| } | |||
| } | |||
| static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||
| int x0, int xmax, int k0, int kmax) { | |||
| static void gemm_s8_8x8_pack_B_n( | |||
| int8_t* out, const int8_t* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| int8_t zerobuff[16]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| const int ksize = kmax - k0; | |||
| @@ -1186,17 +1209,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||
| if (k + 7 >= kmax) { | |||
| switch (k + 7 - kmax) { | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -1205,8 +1234,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||
| } | |||
| } | |||
| outptr_interleave = outptr; | |||
| interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, outptr_interleave); | |||
| interleave_8x8_1_b( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr_interleave); | |||
| outptr += ksize8; | |||
| } | |||
| @@ -1215,17 +1245,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||
| if (k + 7 >= kmax) { | |||
| switch (k + 7 - kmax) { | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -1235,8 +1271,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||
| } | |||
| outptr_interleave = outptr; | |||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
| inptr7, outptr_interleave, 4, 4); | |||
| interleave_8( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr_interleave, 4, 4); | |||
| outptr += ksize4; | |||
| } | |||
| @@ -1244,17 +1281,23 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||
| if (k + 7 >= kmax) { | |||
| switch (k + 7 - kmax) { | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -1264,8 +1307,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||
| } | |||
| outptr_interleave = outptr; | |||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
| inptr7, outptr_interleave, 4, xmax - x); | |||
| interleave_8( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr_interleave, 4, xmax - x); | |||
| } | |||
| outptr_base += 8 * 8; | |||
| @@ -1273,9 +1317,9 @@ static void gemm_s8_8x8_pack_B_n(int8_t* out, const int8_t* in, int ldin, | |||
| } | |||
| } | |||
| static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||
| int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| static void gemm_s8_8x8_transpose_pack_B_n( | |||
| int8_t* outptr, const int8_t* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| int8_t zerobuff[16]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| constexpr int interleave4 = 32; | |||
| @@ -1303,14 +1347,16 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||
| int K = kmax - k0; | |||
| for (; K > 7; K -= 8) { | |||
| transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, outptr); | |||
| transpose_8x8_1_b( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr); | |||
| outptr += interleave8; | |||
| } | |||
| if (K > 0) { | |||
| transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
| inptr7, outptr, 8, K); | |||
| transpose_8( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr, 8, K); | |||
| outptr += interleave8; | |||
| } | |||
| } | |||
| @@ -1331,9 +1377,11 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -1350,9 +1398,11 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -50,8 +50,9 @@ namespace matmul_mk4_4x4x16 { | |||
| * Accumulator | |||
| */ | |||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, bool is_first_k) { | |||
| static void kern_4x4( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, | |||
| bool is_first_k) { | |||
| K = div_ceil(K, 16); | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -366,17 +367,18 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| "6:\n" | |||
| "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[output]], #64\n" | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
| [is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [k] "+r"(K), [output] "+r"(output) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
| "v29", "v30", "v31", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||
| "cc", "memory"); | |||
| } | |||
| static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, bool is_first_k, size_t remain_n) { | |||
| static void kern_4x4_remain( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, | |||
| bool is_first_k, size_t remain_n) { | |||
| K = div_ceil(K, 16); | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -718,26 +720,27 @@ static void kern_4x4_remain(const int8_t* packA, const int8_t* packB, int K, | |||
| "7:\n" | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
| [remain_n] "+r"(remain_n), [is_first_k] "+r"(is_first_k), | |||
| [k] "+r"(K), [output] "+r"(output) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [remain_n] "+r"(remain_n), | |||
| [is_first_k] "+r"(is_first_k), [k] "+r"(K), [output] "+r"(output) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
| "v29", "v30", "v31", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||
| "cc", "memory"); | |||
| } | |||
| static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
| int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| static void gemm_mk4_s8_4x4_pack_A( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| //! pack form {oc/4, ic/4, 4(ic), 4(oc)} to {oc/4, ic/16, 4(oc), 16(ic)} | |||
| int8_t zerobuff[4][64]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 64 * 4); | |||
| megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0, | |||
| "mk4 matmul with m is not times of 4"); | |||
| megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0, | |||
| "mk4 matmul with k is not times of 4"); | |||
| megdnn_assert( | |||
| ymax % 4 == 0 && y0 % 4 == 0 && (ymax - y0) % 4 == 0, | |||
| "mk4 matmul with m is not times of 4"); | |||
| megdnn_assert( | |||
| kmax % 4 == 0 && k0 % 4 == 0 && (kmax - k0) % 4 == 0, | |||
| "mk4 matmul with k is not times of 4"); | |||
| size_t roundk = round_up(kmax - k0, 16); | |||
| size_t out_offset = roundk * 4; | |||
| int y = y0; | |||
| @@ -754,8 +757,8 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
| prefetch_2x(inptr3); | |||
| int K = kmax - k0; | |||
| for (; K > 15; K -= 16) { | |||
| transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, | |||
| out_offset); | |||
| transpose_interleave_4x4_4_b( | |||
| inptr0, inptr1, inptr2, inptr3, output, out_offset); | |||
| output += 64; | |||
| } | |||
| if (K > 0) { | |||
| @@ -767,8 +770,8 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
| inptr1 = zerobuff[1]; | |||
| inptr2 = zerobuff[2]; | |||
| inptr3 = zerobuff[3]; | |||
| transpose_interleave_4x4_4_b(inptr0, inptr1, inptr2, inptr3, output, | |||
| out_offset); | |||
| transpose_interleave_4x4_4_b( | |||
| inptr0, inptr1, inptr2, inptr3, output, out_offset); | |||
| output += 64; | |||
| } | |||
| } | |||
| @@ -790,21 +793,21 @@ static void gemm_mk4_s8_4x4_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
| } | |||
| } | |||
| static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
| int x0, int xmax, int k0, int kmax) { | |||
| static void gemm_mk4_s8_4x4_pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| int32_t zerobuff[4]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| const int ksize = kmax - k0; | |||
| const int ICB = (ksize) / 4; | |||
| const int ksize4 = round_up<int>(ICB, 4) * 4; | |||
| int32_t* outptr = reinterpret_cast<int32_t*>(out); | |||
| megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0, | |||
| "mk4 matmul with k is not times of 4"); | |||
| megdnn_assert( | |||
| kmax % 4 == 0 && k0 % 4 == 0 && ksize % 4 == 0, | |||
| "mk4 matmul with k is not times of 4"); | |||
| int k = k0 / 4; | |||
| for (; k + 3 < ICB; k += 4) { | |||
| const int32_t* inptr0 = | |||
| reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||
| const int32_t* inptr0 = reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||
| const int32_t* inptr1 = | |||
| reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0); | |||
| const int32_t* inptr2 = | |||
| @@ -829,8 +832,7 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
| outptr += 4 * 4; | |||
| } | |||
| if (k < ICB) { | |||
| const int32_t* inptr0 = | |||
| reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||
| const int32_t* inptr0 = reinterpret_cast<const int32_t*>(in + k * ldin + x0); | |||
| const int32_t* inptr1 = | |||
| reinterpret_cast<const int32_t*>(in + (k + 1) * ldin + x0); | |||
| const int32_t* inptr2 = | |||
| @@ -844,9 +846,11 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
| if (k + 3 >= ICB) { | |||
| switch (k + 3 - ICB) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -861,9 +865,11 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
| if (k + 3 >= ICB) { | |||
| switch (k + 3 - ICB) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -882,7 +888,7 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
| } | |||
| } | |||
| } // namespace matmul_4x4x16 | |||
| } // namespace matmul_mk4_4x4x16 | |||
| } // namespace aarch64 | |||
| } // namespace megdnn | |||
| @@ -24,20 +24,19 @@ using namespace aarch64::matmul; | |||
| ///////////////////////// gemm_s8_4x4 //////////////////////////////////// | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4); | |||
| void gemm_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||
| int y0, int ymax, int k0, int kmax, | |||
| bool transpose) const { | |||
| void gemm_s8_4x4::pack_A( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool transpose) const { | |||
| if (transpose) { | |||
| matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_4x4x16::gemm_s8_4x4_pack_B_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
| } else { | |||
| matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_4x4x16::gemm_s8_4x4_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| } | |||
| void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
| int xmax, int k0, int kmax, bool transpose) const { | |||
| void gemm_s8_4x4::pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| if (transpose) { | |||
| matmul_4x4x16::gemm_s8_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | |||
| } else { | |||
| @@ -45,16 +44,16 @@ void gemm_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
| } | |||
| } | |||
| void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
| size_t N, size_t K, dt_int32* C, size_t LDC, | |||
| bool is_first_k, const dt_int32*, dt_int32*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
| C_dtype.name()); | |||
| void gemm_s8_4x4::kern( | |||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
| dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||
| megdnn_assert( | |||
| A_dtype.enumv() == B_dtype.enumv() && | |||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| @@ -72,16 +71,15 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
| size_t n = 0; | |||
| const dt_int8* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k); | |||
| matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K4; | |||
| } | |||
| for (; n < N; n += B_INTERLEAVE) { | |||
| matmul_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, LDC, | |||
| is_first_k, 4, | |||
| std::min<size_t>(N - n, 4)); | |||
| matmul_4x4x16::kern_4x4_remain( | |||
| packA, cur_packB, K, output, LDC, is_first_k, 4, | |||
| std::min<size_t>(N - n, 4)); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -107,33 +105,32 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
| ///////////////////////// gemm_mk4_s8_4x4 //////////////////////////////////// | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4); | |||
| void gemm_mk4_s8_4x4::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||
| int y0, int ymax, int k0, int kmax, | |||
| bool transpose) const { | |||
| megdnn_assert(!transpose, | |||
| "the gemm_mk4_s8_4x4 strategy is not support transpose A"); | |||
| matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0, | |||
| kmax); | |||
| void gemm_mk4_s8_4x4::pack_A( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool transpose) const { | |||
| megdnn_assert( | |||
| !transpose, "the gemm_mk4_s8_4x4 strategy is not support transpose A"); | |||
| matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_A(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| void gemm_mk4_s8_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
| int xmax, int k0, int kmax, bool transpose) const { | |||
| megdnn_assert(!transpose, | |||
| "the gemm_mk4_s8_4x4 strategy is not support transpose B"); | |||
| matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0, | |||
| kmax); | |||
| void gemm_mk4_s8_4x4::pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| megdnn_assert( | |||
| !transpose, "the gemm_mk4_s8_4x4 strategy is not support transpose B"); | |||
| matmul_mk4_4x4x16::gemm_mk4_s8_4x4_pack_B(out, in, ldin, x0, xmax, k0, kmax); | |||
| } | |||
| void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
| size_t N, size_t K, dt_int32* C, size_t LDC, | |||
| bool is_first_k, const dt_int32*, dt_int32*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
| C_dtype.name()); | |||
| void gemm_mk4_s8_4x4::kern( | |||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
| dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||
| megdnn_assert( | |||
| A_dtype.enumv() == B_dtype.enumv() && | |||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| @@ -151,57 +148,54 @@ void gemm_mk4_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
| size_t n = 0; | |||
| const dt_int8* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output, | |||
| is_first_k); | |||
| matmul_mk4_4x4x16::kern_4x4(packA, cur_packB, K, output, is_first_k); | |||
| output += B_INTERLEAVE * 4; | |||
| cur_packB += K4; | |||
| } | |||
| if (n < N) { | |||
| matmul_mk4_4x4x16::kern_4x4_remain(packA, cur_packB, K, output, | |||
| is_first_k, N - n); | |||
| matmul_mk4_4x4x16::kern_4x4_remain( | |||
| packA, cur_packB, K, output, is_first_k, N - n); | |||
| } | |||
| packA += K4; | |||
| } | |||
| } | |||
| ///////////////////////// gemm_s8_8x8 //////////////////////////////////// | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x8); | |||
| void gemm_s8_8x8::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||
| int y0, int ymax, int k0, int kmax, | |||
| bool transpose) const { | |||
| void gemm_s8_8x8::pack_A( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool transpose) const { | |||
| if (transpose) { | |||
| matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n(outptr, inptr, ldin, y0, | |||
| ymax, k0, kmax); | |||
| matmul_8x8x8::gemm_s8_8x8_transpose_pack_A_n( | |||
| outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
| } else { | |||
| matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_8x8x8::gemm_s8_8x8_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| } | |||
| void gemm_s8_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
| int xmax, int k0, int kmax, bool transpose) const { | |||
| void gemm_s8_8x8::pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| if (transpose) { | |||
| matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, | |||
| k0, kmax); | |||
| matmul_8x8x8::gemm_s8_8x8_transpose_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
| } else { | |||
| matmul_8x8x8::gemm_s8_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
| } | |||
| } | |||
| void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
| size_t N, size_t K, dt_int32* C, size_t LDC, | |||
| bool is_first_k, const dt_int32*, dt_int32*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
| C_dtype.name()); | |||
| void gemm_s8_8x8::kern( | |||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
| dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||
| megdnn_assert( | |||
| A_dtype.enumv() == B_dtype.enumv() && | |||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| @@ -220,15 +214,15 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
| size_t n = 0; | |||
| const dt_int8* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, | |||
| is_first_k); | |||
| matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K8; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| matmul_8x8x8::kern_8x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| output += 4; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -240,16 +234,17 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
| const dt_int8* cur_packB = packB; | |||
| size_t n = 0; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4)); | |||
| matmul_8x8x8::kern_4x8( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4)); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K8; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4), | |||
| std::min<size_t>(N - n, 4)); | |||
| matmul_8x8x8::kern_4x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
| output += 4; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -16,14 +16,14 @@ namespace megdnn { | |||
| namespace aarch64 { | |||
| namespace matmul { | |||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true, | |||
| gemm_s8_4x4); | |||
| MEGDNN_REG_GEMM_STRATEGY( | |||
| dt_int8, dt_int32, dt_int32, 4, 4, 16, false, true, gemm_s8_4x4); | |||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false, | |||
| gemm_mk4_s8_4x4); | |||
| MEGDNN_REG_GEMM_STRATEGY( | |||
| dt_int8, dt_int32, dt_int32, 4, 4, 16, false, false, gemm_mk4_s8_4x4); | |||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, | |||
| gemm_s8_8x8); | |||
| MEGDNN_REG_GEMM_STRATEGY( | |||
| dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, gemm_s8_8x8); | |||
| } // namespace matmul | |||
| } // namespace aarch64 | |||
| @@ -52,8 +52,9 @@ namespace matmul_8x12x4 { | |||
| #if 1 | |||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
| static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k) { | |||
| static void kern_8x12( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k) { | |||
| K /= 4; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -410,8 +411,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||
| } | |||
| #else | |||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
| static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k) { | |||
| static void kern_8x12( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k) { | |||
| K /= 4; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -612,18 +614,17 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||
| "stp q15, q23, [%[outptr7]]\n" | |||
| "str q31, [%[outptr7], #32]\n" | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), | |||
| [a1] "+w"(a1), [a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), | |||
| [b1] "+w"(b1), [b2] "+w"(b2), [k] "+r"(k), [LDC] "+r"(LDC), | |||
| [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | |||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||
| [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [a0] "+w"(a0), [a1] "+w"(a1), | |||
| [a0a] "+w"(a0a), [a1a] "+w"(a1a), [b0] "+w"(b0), [b1] "+w"(b1), | |||
| [b2] "+w"(b2), [k] "+r"(k), [LDC] "+r"(LDC), [oddk] "+r"(oddk), | |||
| [is_first_k] "+r"(is_first_k), [outptr0] "+r"(outptr0), | |||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||
| [outptr7] "=r"(outptr7) | |||
| : | |||
| : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
| "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", | |||
| "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"); | |||
| : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
| "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
| "v29", "v30", "v31", "cc", "memory"); | |||
| } | |||
| #endif | |||
| @@ -653,8 +654,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||
| // | |||
| // Accumulator | |||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
| static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k, int m_remain) { | |||
| static void kern_4x12( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k, int m_remain) { | |||
| K /= 4; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -796,15 +798,15 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||
| "4:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), | |||
| [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), | |||
| [is_first_k] "+r"(is_first_k), [m_remain] "+r"(m_remain), | |||
| [LDC] "+r"(LDC), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), | |||
| [b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a), [b1a] "=w"(b1a), | |||
| [b2a] "=w"(b2a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), | |||
| [outptr3] "=r"(outptr3), [x0] "=r"(x0) | |||
| [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | |||
| [m_remain] "+r"(m_remain), [LDC] "+r"(LDC), [a0] "=w"(a0), | |||
| [a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2), | |||
| [b0a] "=w"(b0a), [b1a] "=w"(b1a), [b2a] "=w"(b2a), | |||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [x0] "=r"(x0) | |||
| : | |||
| : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
| "v17", "v18", "v19", "memory", "cc"); | |||
| : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
| "v19", "memory", "cc"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -840,8 +842,9 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||
| // | |||
| // Accumulator | |||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
| static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k, int n_remain) { | |||
| static void kern_8x4( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k, int n_remain) { | |||
| K /= 4; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -1004,12 +1007,11 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
| [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), | |||
| [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), | |||
| [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), | |||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||
| [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "=r"(x0) | |||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [outptr4] "=r"(outptr4), | |||
| [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), | |||
| [x0] "=r"(x0) | |||
| : | |||
| : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", | |||
| "cc"); | |||
| : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", "cc"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -1041,9 +1043,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
| // | |||
| // Accumulator | |||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k, int m_remain, | |||
| int n_remain) { | |||
| static void kern_4x4( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k, int m_remain, int n_remain) { | |||
| K /= 4; | |||
| const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA); | |||
| const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB); | |||
| @@ -1172,10 +1174,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| "4:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), | |||
| [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), | |||
| [m_remain] "+r"(m_remain), [LDC] "+r"(LDC), | |||
| [outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0), | |||
| [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a), | |||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), | |||
| [m_remain] "+r"(m_remain), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
| [k] "+r"(k), [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), | |||
| [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), | |||
| [outptr3] "=r"(outptr3), [x0] "=r"(x0), [x1] "=r"(x1) | |||
| : | |||
| : "v4", "v5", "v6", "v7", "memory", "cc"); | |||
| @@ -1186,9 +1187,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| #undef STORE_C | |||
| } | |||
| static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
| int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| static void gemm_s8_8x12_pack_A_n( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| int8_t zerobuff[16]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| @@ -1215,13 +1216,15 @@ static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
| int K = kmax - k0; | |||
| //! read 8 * 4 in each row | |||
| for (; K > 15; K -= 16) { | |||
| interleave_8x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, outptr); | |||
| interleave_8x4_4_b( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr); | |||
| } | |||
| if (K > 0) { | |||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
| inptr7, outptr, 4, K); | |||
| interleave_8( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr, 4, K); | |||
| } | |||
| } | |||
| for (; y < ymax; y += 4) { | |||
| @@ -1274,8 +1277,8 @@ static void gemm_s8_8x12_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
| } | |||
| } | |||
| static void gemm_s8_8x12_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, | |||
| int x0, int xmax, int k0, int kmax) { | |||
| static void gemm_s8_8x12_pack_A_t( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| int8_t zerobuff[16]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| const int ksize = kmax - k0; | |||
| @@ -1361,8 +1364,8 @@ static void gemm_s8_8x12_pack_A_t(dt_int8* out, const dt_int8* in, int ldin, | |||
| } | |||
| } | |||
| static void gemm_s8_8x12_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| int x0, int xmax, int k0, int kmax) { | |||
| static void gemm_s8_8x12_pack_B_n( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| int8_t zerobuff[16]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| const int ksize = kmax - k0; | |||
| @@ -1448,9 +1451,9 @@ static void gemm_s8_8x12_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| } | |||
| } | |||
| static void gemm_s8_8x12_pack_B_t(dt_int8* outptr, const dt_int8* inptr, | |||
| int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| static void gemm_s8_8x12_pack_B_t( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| int8_t zerobuff[16]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| @@ -1485,15 +1488,15 @@ static void gemm_s8_8x12_pack_B_t(dt_int8* outptr, const dt_int8* inptr, | |||
| int K = kmax - k0; | |||
| //! read 12 * 4 in each row | |||
| for (; K > 15; K -= 16) { | |||
| interleave_12x4_4_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, inptr8, inptr9, inptr10, | |||
| inptr11, outptr); | |||
| interleave_12x4_4_b( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| inptr8, inptr9, inptr10, inptr11, outptr); | |||
| } | |||
| if (K > 0) { | |||
| interleave_12(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, inptr8, inptr9, inptr10, inptr11, | |||
| outptr, 4, K); | |||
| interleave_12( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| inptr8, inptr9, inptr10, inptr11, outptr, 4, K); | |||
| } | |||
| } | |||
| for (; y < ymax; y += 4) { | |||
| @@ -40,8 +40,9 @@ namespace matmul_mk4_8x12x4 { | |||
| // Accumulator | |||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
| static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k) { | |||
| static void kern_8x12( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k) { | |||
| K /= 4; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -397,8 +398,9 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, | |||
| // Accumulator | |||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
| static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k) { | |||
| static void kern_4x12( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k) { | |||
| K /= 4; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -514,13 +516,12 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||
| "stp q16, q17, [%[outptr0], #128]\n" | |||
| "stp q18, q19, [%[outptr0], #160]\n" | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [k] "+r"(k), | |||
| [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), | |||
| [is_first_k] "+r"(is_first_k), [a0] "=w"(a0), [a0a] "=w"(a0a), | |||
| [b0] "=w"(b0), [b1] "=w"(b1), [b2] "=w"(b2), [b0a] "=w"(b0a), | |||
| [b1a] "=w"(b1a), [b2a] "=w"(b2a) | |||
| [outptr0] "+r"(outptr0), [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | |||
| [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b1] "=w"(b1), | |||
| [b2] "=w"(b2), [b0a] "=w"(b0a), [b1a] "=w"(b1a), [b2a] "=w"(b2a) | |||
| : | |||
| : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", | |||
| "v17", "v18", "v19", "memory", "cc"); | |||
| : "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", | |||
| "v19", "memory", "cc"); | |||
| } | |||
| // Overview of register layout: | |||
| @@ -544,8 +545,9 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, | |||
| // Accumulator | |||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
| static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k, int n_remain) { | |||
| static void kern_8x4( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k, int n_remain) { | |||
| K /= 4; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -689,11 +691,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
| [oddk] "+r"(oddk), [is_first_k] "+r"(is_first_k), | |||
| [n_remain] "+r"(n_remain), [k] "+r"(k), [outptr0] "+r"(outptr0), | |||
| [a0] "=w"(a0), [a1] "=w"(a1), [a0a] "=w"(a0a), [a1a] "=w"(a1a), | |||
| [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), | |||
| [x0] "=r"(x0) | |||
| [b0] "=w"(b0), [b0a] "=w"(b0a), [outptr1] "=r"(outptr1), [x0] "=r"(x0) | |||
| : | |||
| : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", | |||
| "cc"); | |||
| : "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "memory", "cc"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -720,8 +720,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
| // Accumulator | |||
| MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| int32_t* output, int LDC, bool is_first_k, int n_remain) { | |||
| static void kern_4x4( | |||
| const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, | |||
| bool is_first_k, int n_remain) { | |||
| K /= 4; | |||
| const int32_t* a_ptr = reinterpret_cast<const int32_t*>(packA); | |||
| const int32_t* b_ptr = reinterpret_cast<const int32_t*>(packB); | |||
| @@ -834,10 +835,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| "4:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [oddk] "+r"(oddk), | |||
| [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), | |||
| [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), [k] "+r"(k), | |||
| [a0] "=w"(a0), [a0a] "=w"(a0a), [b0] "=w"(b0), [b0a] "=w"(b0a), | |||
| [x0] "=r"(x0) | |||
| [is_first_k] "+r"(is_first_k), [n_remain] "+r"(n_remain), [LDC] "+r"(LDC), | |||
| [outptr0] "+r"(outptr0), [k] "+r"(k), [a0] "=w"(a0), [a0a] "=w"(a0a), | |||
| [b0] "=w"(b0), [b0a] "=w"(b0a), [x0] "=r"(x0) | |||
| : | |||
| : "v4", "v5", "v6", "v7", "memory", "cc"); | |||
| @@ -847,13 +847,11 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| #undef STORE_C | |||
| } | |||
| static void gemm_mk4_s8_8x12_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
| int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0, | |||
| "mk4 matmul with m is not times of 4"); | |||
| megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, | |||
| "mk4 matmul with k is not times of 4"); | |||
| static void gemm_mk4_s8_8x12_pack_A( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| megdnn_assert(ymax % 4 == 0 && y0 % 4 == 0, "mk4 matmul with m is not times of 4"); | |||
| megdnn_assert(kmax % 4 == 0 && k0 % 4 == 0, "mk4 matmul with k is not times of 4"); | |||
| int y = y0; | |||
| int start_y = y0 / 4; | |||
| for (; y + 7 < ymax; y += 8, start_y += 2) { | |||
| @@ -869,15 +867,15 @@ static void gemm_mk4_s8_8x12_pack_A(dt_int8* outptr, const dt_int8* inptr, | |||
| interleave_2x4_4_b(inptr0, inptr1, outptr); | |||
| } | |||
| } | |||
| for (; y + 3 < ymax; y += 4, start_y ++) { | |||
| for (; y + 3 < ymax; y += 4, start_y++) { | |||
| int K = kmax - k0; | |||
| const int8_t* inptr0 = inptr + start_y * ldin + (k0 << 2); | |||
| std::memcpy(outptr, inptr0, sizeof(dt_int8) * K * 4); | |||
| } | |||
| } | |||
| static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
| int x0, int xmax, int k0, int kmax) { | |||
| static void gemm_mk4_s8_8x12_pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| const int ksize = kmax - k0; | |||
| const int ksize12 = ksize * 12; | |||
| const int ksize4 = ksize * 4; | |||
| @@ -12,10 +12,10 @@ | |||
| #include "src/aarch64/matrix_mul/int8_dot/strategy.h" | |||
| #if MGB_ENABLE_DOT | |||
| #include "src/aarch64/matrix_mul/asm/common.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" | |||
| #include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/utils.h" | |||
| using namespace megdnn; | |||
| using namespace aarch64; | |||
| @@ -24,20 +24,19 @@ using namespace aarch64::matmul; | |||
| /* ====================== gemm_s8_8x12 ===========================*/ | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12); | |||
| void gemm_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||
| int y0, int ymax, int k0, int kmax, | |||
| bool transpose) const { | |||
| void gemm_s8_8x12::pack_A( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool transpose) const { | |||
| if (transpose) { | |||
| matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_8x12x4::gemm_s8_8x12_pack_A_t(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
| } else { | |||
| matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_8x12x4::gemm_s8_8x12_pack_A_n(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| } | |||
| void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
| int xmax, int k0, int kmax, bool transpose) const { | |||
| void gemm_s8_8x12::pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| if (transpose) { | |||
| matmul_8x12x4::gemm_s8_8x12_pack_B_t(out, in, ldin, x0, xmax, k0, kmax); | |||
| } else { | |||
| @@ -45,16 +44,16 @@ void gemm_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
| } | |||
| } | |||
| void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
| size_t N, size_t K, dt_int32* C, size_t LDC, | |||
| bool is_first_k, const dt_int32*, dt_int32*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
| C_dtype.name()); | |||
| void gemm_s8_8x12::kern( | |||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
| dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||
| megdnn_assert( | |||
| A_dtype.enumv() == B_dtype.enumv() && | |||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| @@ -75,15 +74,15 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
| size_t n = 0; | |||
| const dt_int8* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k); | |||
| matmul_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K12; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, std::min<size_t>(N - n, 4)); | |||
| matmul_8x12x4::kern_8x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| output += 4; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -95,16 +94,17 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
| const dt_int8* cur_packB = packB; | |||
| size_t n = 0; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k, std::min<size_t>(M - m, 4)); | |||
| matmul_8x12x4::kern_4x12( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4)); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K12; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, std::min<size_t>(M - m, 4), | |||
| std::min<size_t>(N - n, 4)); | |||
| matmul_8x12x4::kern_4x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
| output += 4; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -115,32 +115,32 @@ void gemm_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, size_t M, | |||
| /* ====================== gemm_mk4_s8_8x12 ===========================*/ | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_8x12); | |||
| void gemm_mk4_s8_8x12::pack_A(dt_int8* outptr, const dt_int8* inptr, int ldin, | |||
| int y0, int ymax, int k0, int kmax, | |||
| bool transpose) const { | |||
| megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix A is not supported"); | |||
| matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0, | |||
| kmax); | |||
| void gemm_mk4_s8_8x12::pack_A( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool transpose) const { | |||
| megdnn_assert( | |||
| !transpose, "matrix mul mk4 with transposed matrix A is not supported"); | |||
| matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_A(outptr, inptr, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| void gemm_mk4_s8_8x12::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
| int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| megdnn_assert(!transpose, "matrix mul mk4 with transposed matrix B is not supported"); | |||
| void gemm_mk4_s8_8x12::pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| megdnn_assert( | |||
| !transpose, "matrix mul mk4 with transposed matrix B is not supported"); | |||
| matmul_mk4_8x12x4::gemm_mk4_s8_8x12_pack_B(out, in, ldin, x0, xmax, k0, kmax); | |||
| } | |||
| void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, | |||
| size_t M, size_t N, size_t K, dt_int32* C, | |||
| size_t LDC, bool is_first_k, const dt_int32*, | |||
| dt_int32*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
| C_dtype.name()); | |||
| void gemm_mk4_s8_8x12::kern( | |||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
| dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { | |||
| megdnn_assert( | |||
| A_dtype.enumv() == B_dtype.enumv() && | |||
| ((A_dtype.enumv() == DTypeEnum::Int8 && | |||
| C_dtype.enumv() == DTypeEnum::Int32) || | |||
| (A_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
| C_dtype.enumv() == DTypeEnum::QuantizedS32)), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| @@ -161,15 +161,15 @@ void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, | |||
| size_t n = 0; | |||
| const dt_int8* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k); | |||
| matmul_mk4_8x12x4::kern_8x12(packA, cur_packB, K, output, LDC, is_first_k); | |||
| output += (B_INTERLEAVE << 2); | |||
| cur_packB += K12; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_mk4_8x12x4::kern_8x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, std::min<size_t>(N - n, 4)); | |||
| matmul_mk4_8x12x4::kern_8x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| output += 16; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -181,15 +181,15 @@ void gemm_mk4_s8_8x12::kern(const dt_int8* packA, const dt_int8* packB, | |||
| const dt_int8* cur_packB = packB; | |||
| size_t n = 0; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k); | |||
| matmul_mk4_8x12x4::kern_4x12(packA, cur_packB, K, output, LDC, is_first_k); | |||
| output += (B_INTERLEAVE << 2); | |||
| cur_packB += K12; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_mk4_8x12x4::kern_4x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, std::min<size_t>(N - n, 4)); | |||
| matmul_mk4_8x12x4::kern_4x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| output += 16; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -16,14 +16,14 @@ namespace megdnn { | |||
| namespace aarch64 { | |||
| namespace matmul { | |||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, | |||
| gemm_s8_8x12); | |||
| MEGDNN_REG_GEMM_STRATEGY( | |||
| dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, gemm_s8_8x12); | |||
| MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, | |||
| gemm_mk4_s8_8x12); | |||
| MEGDNN_REG_GEMM_STRATEGY( | |||
| dt_int8, dt_int32, dt_int32, 8, 12, 4, false, true, gemm_mk4_s8_8x12); | |||
| } // namespace aarch64 | |||
| } // namespace matmul | |||
| } // namespace aarch64 | |||
| } // namespace megdnn | |||
| #endif | |||
| // vim: syntax=cpp.doxygen | |||
| @@ -34,9 +34,9 @@ namespace matmul_4x4x16 { | |||
| * | |||
| * Accumulator | |||
| */ | |||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
| int n_remain) { | |||
| static void kern_4x4( | |||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
| bool is_first_k, int m_remain, int n_remain) { | |||
| K /= 16; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -230,16 +230,14 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| // Store back into memory | |||
| STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [outptr] "+r"(outptr), [m_remain] "+r"(m_remain), | |||
| [n_remain] "+r"(n_remain) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||
| [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) | |||
| : | |||
| : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2", | |||
| "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||
| "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", | |||
| "v31"); | |||
| : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "v0", "v1", "v2", "v3", | |||
| "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", | |||
| "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", | |||
| "v25", "v26", "v27", "v28", "v29", "v30", "v31"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -247,9 +245,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| #undef STORE_C | |||
| } | |||
| static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
| int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| static void gemm_s8x8x16_4x4_pack_A_n( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| int8_t zerobuff[16]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| @@ -292,9 +290,11 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -309,9 +309,11 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -324,8 +326,8 @@ static void gemm_s8x8x16_4x4_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
| } | |||
| } | |||
| static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| int x0, int xmax, int k0, int kmax) { | |||
| static void gemm_s8x8x16_4x4_pack_B_n( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| int8_t zerobuff[16]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| const int ksize = kmax - k0; | |||
| @@ -362,19 +364,26 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| if (remain >= 0) { | |||
| switch (remain) { | |||
| case 7: | |||
| inptr0 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr0 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -383,9 +392,9 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| } | |||
| } | |||
| transpose_4x16_1_b_helper(inptr0, inptr1, inptr2, inptr3, | |||
| inptr4, inptr5, inptr6, inptr7, | |||
| outptr_inner); | |||
| transpose_4x16_1_b_helper( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr_inner); | |||
| outptr_inner += ksize4; | |||
| } | |||
| @@ -393,19 +402,26 @@ static void gemm_s8x8x16_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| if (remain >= 0) { | |||
| switch (remain) { | |||
| case 7: | |||
| inptr0 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr0 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -42,8 +42,9 @@ namespace matmul_8x8x8 { | |||
| * | |||
| * Accumulator | |||
| */ | |||
| static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
| int16_t* output, int LDC, bool is_first_k) { | |||
| static void kern_8x8( | |||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
| bool is_first_k) { | |||
| K /= 8; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -217,13 +218,12 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
| "bne 2b\n" | |||
| "3:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [outptr] "+r"(outptr) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "x1", "x2", "x3", | |||
| "x4", "x5", "x6", "x7", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
| "v12", "v13", "v14", "v15", "v16", "v17", "x1", "x2", "x3", "x4", "x5", | |||
| "x6", "x7", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| #undef STORE_LINE | |||
| @@ -258,9 +258,9 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
| * Accumulator | |||
| */ | |||
| static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
| int16_t* output, int LDC, bool is_first_k, | |||
| size_t n_remain) { | |||
| static void kern_8x4( | |||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
| bool is_first_k, size_t n_remain) { | |||
| K /= 8; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -471,16 +471,14 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
| "cbnz %w[K], 2b\n" | |||
| "3:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), | |||
| [outptr6] "=r"(outptr6), [outptr7] "=r"(outptr7), [x0] "+r"(x0), | |||
| [n_remain] "+r"(n_remain) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [outptr4] "=r"(outptr4), [outptr5] "=r"(outptr5), [outptr6] "=r"(outptr6), | |||
| [outptr7] "=r"(outptr7), [x0] "+r"(x0), [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
| "v12", "v13", "v14", "v15", "v16", "v17", "cc", "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -514,9 +512,9 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, | |||
| * Accumulator | |||
| */ | |||
| static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
| int16_t* output, int LDC, bool is_first_k, | |||
| size_t m_remain) { | |||
| static void kern_4x8( | |||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
| bool is_first_k, size_t m_remain) { | |||
| K /= 8; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -646,11 +644,10 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
| "cbnz %w[K], 2b\n" | |||
| "3:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), | |||
| [is_first_k] "+r"(is_first_k), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [outptr0] "+r"(outptr0), [outptr1] "=r"(outptr1), | |||
| [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), [x0] "+r"(x0), | |||
| [m_remain] "+r"(m_remain) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr0] "+r"(outptr0), | |||
| [outptr1] "=r"(outptr1), [outptr2] "=r"(outptr2), [outptr3] "=r"(outptr3), | |||
| [x0] "+r"(x0), [m_remain] "+r"(m_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc", | |||
| "memory"); | |||
| @@ -686,9 +683,9 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
| * | |||
| * Accumulator | |||
| */ | |||
| static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| int16_t* output, int LDC, bool is_first_k, size_t m_remain, | |||
| size_t n_remain) { | |||
| static void kern_4x4( | |||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
| bool is_first_k, size_t m_remain, size_t n_remain) { | |||
| K /= 8; | |||
| const int8_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -853,11 +850,10 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| "3:\n" STORE_C | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [outptr] "+r"(outptr), | |||
| [K] "+r"(K), [is_first_k] "+r"(is_first_k), [LDC] "+r"(LDC), | |||
| [x0] "+r"(x0), [m_remain] "+r"(m_remain), | |||
| [n_remain] "+r"(n_remain) | |||
| [x0] "+r"(x0), [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "x1", | |||
| "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "x1", "cc", | |||
| "memory"); | |||
| #undef LOAD_LINE | |||
| #undef LOAD_C | |||
| @@ -865,9 +861,9 @@ static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| #undef STORE_C | |||
| } | |||
| static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
| int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| static void gemm_s8x8x16_8x8_pack_A_n( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| int8_t zerobuff[16]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| @@ -893,13 +889,15 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
| int K = kmax - k0; | |||
| for (; K > 15; K -= 16) { | |||
| interleave_8x8_2_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, outptr); | |||
| interleave_8x8_2_b( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr); | |||
| } | |||
| if (K > 0) { | |||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
| inptr7, outptr, 8, K); | |||
| interleave_8( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr, 8, K); | |||
| } | |||
| } | |||
| for (; y < ymax; y += 4) { | |||
| @@ -918,9 +916,11 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -936,9 +936,11 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -951,9 +953,8 @@ static void gemm_s8x8x16_8x8_pack_A_n(dt_int8* outptr, const dt_int8* inptr, | |||
| } | |||
| } | |||
| static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||
| int ldin, int x0, int xmax, | |||
| int k0, int kmax) { | |||
| static void gemm_s8x8x16_8x8_transpose_pack_A_n( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| int8_t zerobuff[16]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| @@ -991,17 +992,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||
| if (k + 7 >= kmax) { | |||
| switch (k + 7 - kmax) { | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -1009,8 +1016,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||
| megdnn_assert(0); | |||
| } | |||
| } | |||
| transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, outptr); | |||
| transpose_8x8_1_b( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr); | |||
| outptr += ksize8; | |||
| } | |||
| @@ -1019,17 +1027,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||
| if (k + 7 >= kmax) { | |||
| switch (k + 7 - kmax) { | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -1038,8 +1052,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||
| } | |||
| } | |||
| transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
| inptr7, outptr, 4, 4); | |||
| transpose_8( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr, 4, 4); | |||
| outptr += ksize4; | |||
| } | |||
| @@ -1047,17 +1062,23 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||
| if (k + 7 >= kmax) { | |||
| switch (k + 7 - kmax) { | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -1066,8 +1087,9 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||
| } | |||
| } | |||
| transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
| inptr7, outptr, 4, xmax - x); | |||
| transpose_8( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr, 4, xmax - x); | |||
| } | |||
| outptr_base += 8 * 8; | |||
| @@ -1075,8 +1097,8 @@ static void gemm_s8x8x16_8x8_transpose_pack_A_n(dt_int8* out, const dt_int8* in, | |||
| } | |||
| } | |||
| static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| int x0, int xmax, int k0, int kmax) { | |||
| static void gemm_s8x8x16_8x8_pack_B_n( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax) { | |||
| int8_t zerobuff[16]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| const int ksize = kmax - k0; | |||
| @@ -1113,17 +1135,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| if (k + 7 >= kmax) { | |||
| switch (k + 7 - kmax) { | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -1132,8 +1160,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| } | |||
| } | |||
| outptr_interleave = outptr; | |||
| interleave_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, outptr_interleave); | |||
| interleave_8x8_1_b( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr_interleave); | |||
| outptr += ksize8; | |||
| } | |||
| @@ -1142,17 +1171,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| if (k + 7 >= kmax) { | |||
| switch (k + 7 - kmax) { | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -1162,8 +1197,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| } | |||
| outptr_interleave = outptr; | |||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
| inptr7, outptr_interleave, 4, 4); | |||
| interleave_8( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr_interleave, 4, 4); | |||
| outptr += ksize4; | |||
| } | |||
| @@ -1171,17 +1207,23 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| if (k + 7 >= kmax) { | |||
| switch (k + 7 - kmax) { | |||
| case 6: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 5: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 4: | |||
| inptr3 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr3 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 3: | |||
| inptr4 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr4 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 2: | |||
| inptr5 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr5 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr6 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr6 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr7 = zerobuff; | |||
| break; | |||
| @@ -1191,8 +1233,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| } | |||
| outptr_interleave = outptr; | |||
| interleave_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
| inptr7, outptr_interleave, 4, xmax - x); | |||
| interleave_8( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr_interleave, 4, xmax - x); | |||
| } | |||
| outptr_base += 8 * 8; | |||
| @@ -1200,10 +1243,9 @@ static void gemm_s8x8x16_8x8_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, | |||
| } | |||
| } | |||
| static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, | |||
| const dt_int8* inptr, int ldin, | |||
| int y0, int ymax, int k0, | |||
| int kmax) { | |||
| static void gemm_s8x8x16_8x8_transpose_pack_B_n( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int y0, int ymax, int k0, | |||
| int kmax) { | |||
| int8_t zerobuff[16]; | |||
| std::memset(zerobuff, 0, sizeof(int8_t) * 16); | |||
| constexpr int interleave4 = 32; | |||
| @@ -1231,14 +1273,16 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, | |||
| int K = kmax - k0; | |||
| for (; K > 7; K -= 8) { | |||
| transpose_8x8_1_b(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, | |||
| inptr6, inptr7, outptr); | |||
| transpose_8x8_1_b( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr); | |||
| outptr += interleave8; | |||
| } | |||
| if (K > 0) { | |||
| transpose_8(inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, | |||
| inptr7, outptr, 8, K); | |||
| transpose_8( | |||
| inptr0, inptr1, inptr2, inptr3, inptr4, inptr5, inptr6, inptr7, | |||
| outptr, 8, K); | |||
| outptr += interleave8; | |||
| } | |||
| } | |||
| @@ -1259,9 +1303,11 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -1278,9 +1324,11 @@ static void gemm_s8x8x16_8x8_transpose_pack_B_n(dt_int8* outptr, | |||
| if (y + 3 >= ymax) { | |||
| switch (y + 3 - ymax) { | |||
| case 2: | |||
| inptr1 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr1 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 1: | |||
| inptr2 = zerobuff; MEGDNN_FALLTHRU | |||
| inptr2 = zerobuff; | |||
| MEGDNN_FALLTHRU | |||
| case 0: | |||
| inptr3 = zerobuff; | |||
| break; | |||
| @@ -40,11 +40,9 @@ namespace matmul_mk4_16x12x4_a53 { | |||
| * Accumulator | |||
| */ | |||
| // clang-format on | |||
| static __attribute__((noinline)) void kern_16x12(const int16_t* packA, | |||
| const int8_t* packB, int K, | |||
| int16_t* output, int LDC, | |||
| bool is_first_k, | |||
| int remain_n) { | |||
| static __attribute__((noinline)) void kern_16x12( | |||
| const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
| bool is_first_k, int remain_n) { | |||
| K /= 4; | |||
| const int16_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -521,15 +519,15 @@ static __attribute__((noinline)) void kern_16x12(const int16_t* packA, | |||
| "6:\n" STORE_C | |||
| "101:\n" | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), | |||
| [remain_n] "+r"(remain_n) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
| "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||
| "x8", "x9", "x10", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||
| "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", | |||
| "memory"); | |||
| #undef STORE_C | |||
| #undef STORE_LINE | |||
| @@ -554,10 +552,9 @@ static __attribute__((noinline)) void kern_16x12(const int16_t* packA, | |||
| * Accumulator | |||
| */ | |||
| // clang-format on | |||
| static __attribute__((noinline)) void kern_8x12(const int16_t* packA, | |||
| const int8_t* packB, int K, | |||
| int16_t* output, int LDC, | |||
| bool is_first_k, int remain_n) { | |||
| static __attribute__((noinline)) void kern_8x12( | |||
| const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
| bool is_first_k, int remain_n) { | |||
| K /= 4; | |||
| const int16_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -858,14 +855,13 @@ static __attribute__((noinline)) void kern_8x12(const int16_t* packA, | |||
| "6:\n" STORE_C | |||
| "101:\n" | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), | |||
| [remain_n] "+r"(remain_n) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", | |||
| "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", "x2", "x3", | |||
| "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", "memory"); | |||
| #undef STORE_C | |||
| #undef STORE_LINE | |||
| @@ -890,10 +886,9 @@ static __attribute__((noinline)) void kern_8x12(const int16_t* packA, | |||
| * Accumulator | |||
| */ | |||
| // clang-format on | |||
| static __attribute__((noinline)) void kern_4x12(const int16_t* packA, | |||
| const int8_t* packB, int K, | |||
| int16_t* output, int LDC, | |||
| bool is_first_k, int remain_n) { | |||
| static __attribute__((noinline)) void kern_4x12( | |||
| const int16_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
| bool is_first_k, int remain_n) { | |||
| K /= 4; | |||
| const int16_t* a_ptr = packA; | |||
| const int8_t* b_ptr = packB; | |||
| @@ -1162,22 +1157,21 @@ static __attribute__((noinline)) void kern_4x12(const int16_t* packA, | |||
| "6:\n" STORE_C | |||
| "101:\n" | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [LDC] "+r"(LDC), [is_first_k] "+r"(is_first_k), | |||
| [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [LDC] "+r"(LDC), | |||
| [is_first_k] "+r"(is_first_k), [outptr] "+r"(outptr), | |||
| [remain_n] "+r"(remain_n) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", | |||
| "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "x1", "x2", "x3", | |||
| "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", "memory"); | |||
| #undef STORE_C | |||
| #undef STORE_LINE | |||
| } | |||
| static void gemm_s8x8x16_mk4_16x12_pack_A(dt_int16* outptr, | |||
| const dt_int8* inptr, int ldin, | |||
| int m0, int mmax, int k0, int kmax) { | |||
| static void gemm_s8x8x16_mk4_16x12_pack_A( | |||
| dt_int16* outptr, const dt_int8* inptr, int ldin, int m0, int mmax, int k0, | |||
| int kmax) { | |||
| megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); | |||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | |||
| constexpr int pack_m = 16; | |||
| @@ -1224,9 +1218,8 @@ static void gemm_s8x8x16_mk4_16x12_pack_A(dt_int16* outptr, | |||
| } | |||
| } | |||
| static void gemm_s8x8x16_mk4_16x12_pack_B(dt_int8* out, const dt_int8* in, | |||
| int ldin, int n0, int nmax, int k0, | |||
| int kmax) { | |||
| static void gemm_s8x8x16_mk4_16x12_pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) { | |||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | |||
| constexpr int pack_n = 12; | |||
| @@ -43,8 +43,9 @@ namespace matmul_mk4_4x4x8_a72 { | |||
| */ | |||
| // clang-format on | |||
| static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| int16_t* output, int LDC, bool, int remain_n) { | |||
| static inline void kern_4x4( | |||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, bool, | |||
| int remain_n) { | |||
| K = div_ceil(K, 8); | |||
| int oddk = (K & 1); | |||
| K = ((K + 1) / 2) - 1; | |||
| @@ -261,15 +262,14 @@ static inline void kern_4x4(const int8_t* packA, const int8_t* packB, int K, | |||
| "7:\n" STORE_C | |||
| "101:\n" | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), | |||
| [oddk] "+r"(oddk), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||
| [remain_n] "+r"(remain_n) | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [K] "+r"(K), [oddk] "+r"(oddk), | |||
| [LDC] "+r"(LDC), [outptr] "+r"(outptr), [remain_n] "+r"(remain_n) | |||
| : | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
| "v29", "v30", "v31", "x1", "x2", "x3", "x4", "x5", "x6", "x7", | |||
| "x8", "x9", "x10", "cc", "memory"); | |||
| : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", | |||
| "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", | |||
| "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", | |||
| "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "cc", | |||
| "memory"); | |||
| #undef STORE_C | |||
| #undef STORE_LINE | |||
| @@ -282,26 +282,23 @@ static inline void transpose_8x4_b(const dt_int8* inptr, dt_int8* outptr) { | |||
| vst1_s8(outptr + 3 * 8, in0.val[3]); | |||
| } | |||
| static inline void interleve_8x4_b(const dt_int8* inptr, const dt_int8* inptr2, | |||
| dt_int8* outptr) { | |||
| static inline void interleve_8x4_b( | |||
| const dt_int8* inptr, const dt_int8* inptr2, dt_int8* outptr) { | |||
| int8x16_t in0 = vld1q_s8(inptr); | |||
| int8x16_t in1 = vld1q_s8(inptr2); | |||
| int32x4x2_t in_x2 = { | |||
| {vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||
| int32x4x2_t in_x2 = {{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||
| vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); | |||
| } | |||
| static inline void interleve_8x4_b_pad(const dt_int8* inptr, dt_int8* outptr) { | |||
| int8x16_t in0 = vld1q_s8(inptr); | |||
| int8x16_t in1 = vdupq_n_s8(0); | |||
| int32x4x2_t in_x2 = { | |||
| {vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||
| int32x4x2_t in_x2 = {{vreinterpretq_s32_s8(in0), vreinterpretq_s32_s8(in1)}}; | |||
| vst2q_s32(reinterpret_cast<int32_t*>(outptr), in_x2); | |||
| } | |||
| static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in, | |||
| int ldin, int m0, int mmax, int k0, | |||
| int kmax) { | |||
| static void gemm_s8x8x16_mk4_4x4x8_pack_A( | |||
| dt_int8* out, const dt_int8* in, int ldin, int m0, int mmax, int k0, int kmax) { | |||
| megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); | |||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | |||
| constexpr int pack_m = 4; | |||
| @@ -330,9 +327,8 @@ static void gemm_s8x8x16_mk4_4x4x8_pack_A(dt_int8* out, const dt_int8* in, | |||
| } | |||
| } | |||
| static void gemm_s8x8x16_mk4_4x4x8_pack_B(dt_int8* out, const dt_int8* in, | |||
| int ldin, int n0, int nmax, int k0, | |||
| int kmax) { | |||
| static void gemm_s8x8x16_mk4_4x4x8_pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) { | |||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | |||
| constexpr int pack_n = 4; | |||
| @@ -18,7 +18,6 @@ namespace megdnn { | |||
| namespace aarch64 { | |||
| namespace matmul_mk4_8x8x8 { | |||
| /** | |||
| * Overview of register layout: | |||
| * | |||
| @@ -39,18 +38,18 @@ namespace matmul_mk4_8x8x8 { | |||
| * | v16 | | v28 | | |||
| * | v17 | | v29 | | |||
| * | v16 | | v30 | | |||
| * | v17 | | v31 | | |||
| * | v17 | | v31 | | |||
| * +--------+ - - - - +---------------------------------+ | |||
| * | |||
| * Accumulator | |||
| */ | |||
| static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
| int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
| int n_remain) { | |||
| static void kern_8x8( | |||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
| bool is_first_k, int m_remain, int n_remain) { | |||
| K /= 8; | |||
| LDC = LDC * sizeof(int16_t); | |||
| const int8_t* a_ptr = packB;//packA; | |||
| const int8_t* b_ptr = packA;//packB; | |||
| const int8_t* a_ptr = packB; // packA; | |||
| const int8_t* b_ptr = packA; // packB; | |||
| // clang-format off | |||
| #define LOAD_C_8 \ | |||
| "ld1 {v0.8h}, [x0], #16\n" \ | |||
| @@ -291,17 +290,17 @@ static void kern_8x8(const int8_t* packA, const int8_t* packB, int K, | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
| "v29", "v30", "v31"); | |||
| // clang-format on | |||
| // clang-format on | |||
| } | |||
| static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
| int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
| int n_remain) { | |||
| static void kern_8x8_remain( | |||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
| bool is_first_k, int m_remain, int n_remain) { | |||
| K /= 8; | |||
| LDC = LDC * sizeof(int16_t); | |||
| const int8_t* a_ptr = packB; | |||
| const int8_t* b_ptr = packA; | |||
| // clang-format off | |||
| // clang-format off | |||
| register int16_t* outptr asm("x0") = output; | |||
| asm volatile( | |||
| "add x1, x0, %x[LDC]\n" | |||
| @@ -476,7 +475,7 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
| "cbnz %w[K], 1b\n" | |||
| "cmp %w[is_first_k], #1\n" | |||
| "beq 2f\n" | |||
| "beq 2f\n" | |||
| "cmp %x[m_remain], #8 \n" | |||
| "beq 8f \n" | |||
| "cmp %x[m_remain], #4 \n" | |||
| @@ -633,7 +632,7 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
| "zip2 v15.2d, v30.2d, v31.2d \n" | |||
| "add v6.8h, v6.8h, v13.8h \n" | |||
| "add v7.8h, v7.8h, v15.8h \n" | |||
| //save to memory | |||
| // save to memory | |||
| "cmp %x[m_remain], #8 \n" | |||
| "beq 4f \n" | |||
| "cmp %x[m_remain], #4 \n" | |||
| @@ -766,31 +765,27 @@ static void kern_8x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
| "b 1000f \n" | |||
| "1000: \n" | |||
| : [a_ptr] "+r"(a_ptr), [b_ptr] "+r"(b_ptr), [is_first_k] "+r"(is_first_k), | |||
| [K] "+r"(K), [LDC] "+r"(LDC), [outptr] "+r"(outptr), | |||
| [m_remain] "+r"(m_remain), [n_remain] "+r"(n_remain) | |||
| : | |||
| [ a_ptr ] "+r"(a_ptr), [ b_ptr ] "+r"(b_ptr), | |||
| [ is_first_k ] "+r"(is_first_k), [ K ] "+r"(K), [ LDC ] "+r"(LDC), | |||
| [ outptr ] "+r"(outptr), [ m_remain ] "+r"(m_remain), | |||
| [ n_remain ] "+r"(n_remain) | |||
| : | |||
| : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", | |||
| "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", | |||
| "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", | |||
| "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", | |||
| "v29", "v30", "v31"); | |||
| // clang-format on | |||
| : "cc", "memory", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "v0", | |||
| "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", | |||
| "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", | |||
| "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); | |||
| // clang-format on | |||
| #undef LOAD_C_8 | |||
| #undef STORE_C_8 | |||
| } | |||
| static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
| int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
| int n_remain) { | |||
| static void kern_4x8( | |||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
| bool is_first_k, int m_remain, int n_remain) { | |||
| K /= 8; | |||
| LDC = LDC * sizeof(int16_t); | |||
| const int8_t* a_ptr = packB;//packA; | |||
| const int8_t* b_ptr = packA;//packB; | |||
| const int8_t* a_ptr = packB; // packA; | |||
| const int8_t* b_ptr = packA; // packB; | |||
| // clang-format off | |||
| #define LOAD_C_4 \ | |||
| "ld1 {v0.8h}, [x0], #16\n" \ | |||
| @@ -1018,14 +1013,14 @@ static void kern_4x8(const int8_t* packA, const int8_t* packB, int K, | |||
| #undef LOAD_C_4 | |||
| #undef STORE_C_4 | |||
| } | |||
| static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
| int16_t* output, int LDC, bool is_first_k, int m_remain, | |||
| int n_remain) { | |||
| static void kern_4x8_remain( | |||
| const int8_t* packA, const int8_t* packB, int K, int16_t* output, int LDC, | |||
| bool is_first_k, int m_remain, int n_remain) { | |||
| K /= 8; | |||
| LDC = LDC * sizeof(int16_t); | |||
| const int8_t* a_ptr = packB;//packA; | |||
| const int8_t* b_ptr = packA;//packB; | |||
| // clang-format off | |||
| const int8_t* a_ptr = packB; // packA; | |||
| const int8_t* b_ptr = packA; // packB; | |||
| // clang-format off | |||
| register int16_t* outptr asm("x0") = output; | |||
| asm volatile( | |||
| @@ -1324,13 +1319,12 @@ static void kern_4x8_remain(const int8_t* packA, const int8_t* packB, int K, | |||
| #undef STORE_C_4 | |||
| } | |||
| //! pack to icxoc | |||
| //! (M/4,K/4,4(K),4(M)) pack to (M/8,k/8,8(K_ic_0~3_ic_4~7),8(M_oc0~3_OC_4~7)) | |||
| //! if M K is not times of 8,pack 0 instead | |||
| static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, | |||
| const dt_int8* inptr, int ldin, | |||
| int m0, int mmax, int k0, int kmax) { | |||
| //! if M K is not times of 8,pack 0 instead | |||
| static void gemm_s8x8x16_mk4_8x8x8_pack_A( | |||
| dt_int8* outptr, const dt_int8* inptr, int ldin, int m0, int mmax, int k0, | |||
| int kmax) { | |||
| megdnn_assert(m0 % 4 == 0 && mmax % 4 == 0, "M must be time of 4"); | |||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | |||
| constexpr int pack_m = 8; | |||
| @@ -1349,8 +1343,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, | |||
| prefetch_2x(inptr0); | |||
| prefetch_2x(inptr1); | |||
| int k_idx = k0; | |||
| for ( ; k_idx + 7 < kmax; k_idx += pack_k) { | |||
| interleave_8x8_mk4_b(inptr0,inptr1,outptr); | |||
| for (; k_idx + 7 < kmax; k_idx += pack_k) { | |||
| interleave_8x8_mk4_b(inptr0, inptr1, outptr); | |||
| } | |||
| if (k_idx < kmax) { | |||
| @@ -1368,9 +1362,9 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, | |||
| prefetch_2x(inptr0); | |||
| prefetch_2x(inptr1); | |||
| int k_idx = k0; | |||
| for ( ; k_idx + 7 < kmax; k_idx += pack_k) { | |||
| for (; k_idx + 7 < kmax; k_idx += pack_k) { | |||
| inptr1 = zerobuff; | |||
| interleave_8x8_mk4_b(inptr0,inptr1,outptr); | |||
| interleave_8x8_mk4_b(inptr0, inptr1, outptr); | |||
| } | |||
| if (k_idx < kmax) { | |||
| @@ -1383,9 +1377,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_A(dt_int8* outptr, | |||
| } | |||
| //! pack to nxic | |||
| //! (K/4,N,4) pack to K/8,N,8(ic0~7) ,K is not times of 8 ,pack 0 instead. | |||
| static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||
| int ldin, int n0, int nmax, int k0, | |||
| int kmax) { | |||
| static void gemm_s8x8x16_mk4_8x8x8_pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int n0, int nmax, int k0, int kmax) { | |||
| megdnn_assert(k0 % 4 == 0 && kmax % 4 == 0, "K must be time of 4"); | |||
| constexpr int pack_n = 8; | |||
| @@ -1394,14 +1387,14 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||
| int8_t tmpbuff0[pack_n * pack_size] = {0}; | |||
| int8_t tmpbuff1[pack_n * pack_size] = {0}; | |||
| int8_t zerobuff[pack_n * pack_size] = {0}; | |||
| const int ksize = round_up<int>((kmax - k0),8); | |||
| const int ksize = round_up<int>((kmax - k0), 8); | |||
| const int nsize = nmax - n0; | |||
| const int n_end = nsize / pack_n * pack_n + n0; | |||
| const int remain_n = nsize % pack_n; | |||
| int output_stride = ksize * pack_n; | |||
| int8_t* outptr_base = out; | |||
| int k_idx = k0; | |||
| for ( ; k_idx + 7 < kmax; k_idx += pack_k) { | |||
| for (; k_idx + 7 < kmax; k_idx += pack_k) { | |||
| const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size; | |||
| const int8_t* inptr1 = inptr0 + ldin; | |||
| prefetch_3x(inptr0); | |||
| @@ -1410,7 +1403,7 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||
| auto outptr = outptr_base; | |||
| for (int n_idx = n0; n_idx < n_end; n_idx += pack_n) { | |||
| transpose_8x8_mk4_b(inptr0, inptr1, outptr); | |||
| outptr += output_stride; | |||
| outptr += output_stride; | |||
| } | |||
| if (remain_n > 0) { | |||
| memcpy(tmpbuff0, inptr0, sizeof(int8_t) * remain_n * pack_size); | |||
| @@ -1422,8 +1415,8 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||
| } | |||
| outptr_base += pack_n * pack_k; | |||
| } | |||
| if(k_idx < kmax){ | |||
| if (k_idx < kmax) { | |||
| const int8_t* inptr0 = in + k_idx / pack_size * ldin + n0 * pack_size; | |||
| const int8_t* inptr1 = nullptr; | |||
| prefetch_3x(inptr0); | |||
| @@ -1444,7 +1437,7 @@ static void gemm_s8x8x16_mk4_8x8x8_pack_B(dt_int8* out, const dt_int8* in, | |||
| } | |||
| } | |||
| } // namespace matmul_mk4_16x12x4_a53 | |||
| } // namespace matmul_mk4_8x8x8 | |||
| } // namespace aarch64 | |||
| } // namespace megdnn | |||
| @@ -10,13 +10,13 @@ | |||
| * implied. | |||
| */ | |||
| #include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | |||
| #include "src/aarch64/matrix_mul/asm/common.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_4x4x16.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_8x8x8.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_16x12x4_a53.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_4x4x8_a72.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/strategy.h" | |||
| #include "src/aarch64/matrix_mul/int8x8x16/kernel_mk4_8x8x8.h" | |||
| #include "src/arm_common/simd_macro/marm_neon.h" | |||
| #include "src/common/utils.h" | |||
| #include "src/fallback/matrix_mul/gemm_common.h" | |||
| @@ -28,39 +28,35 @@ using namespace aarch64::matmul; | |||
| // ===========================gemm_s8x8x16_4x4================================== | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_8x8); | |||
| void gemm_s8x8x16_8x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | |||
| int ymax, int k0, int kmax, | |||
| bool transpose) const { | |||
| void gemm_s8x8x16_8x8::pack_A( | |||
| dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
| bool transpose) const { | |||
| if (transpose) { | |||
| matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n(out, in, ldin, y0, | |||
| ymax, k0, kmax); | |||
| matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_A_n( | |||
| out, in, ldin, y0, ymax, k0, kmax); | |||
| } else { | |||
| matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_8x8x8::gemm_s8x8x16_8x8_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| } | |||
| void gemm_s8x8x16_8x8::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
| int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| void gemm_s8x8x16_8x8::pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| if (transpose) { | |||
| matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n(out, in, ldin, x0, | |||
| xmax, k0, kmax); | |||
| matmul_8x8x8::gemm_s8x8x16_8x8_transpose_pack_B_n( | |||
| out, in, ldin, x0, xmax, k0, kmax); | |||
| } else { | |||
| matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, | |||
| kmax); | |||
| matmul_8x8x8::gemm_s8x8x16_8x8_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
| } | |||
| } | |||
| void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
| size_t M, size_t N, size_t K, dt_int16* C, | |||
| size_t LDC, bool is_first_k, const dt_int16*, | |||
| dt_int16*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| (A_dtype.enumv() == DTypeEnum::Int8 && | |||
| C_dtype.enumv() == DTypeEnum::Int16), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
| C_dtype.name()); | |||
| void gemm_s8x8x16_8x8::kern( | |||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
| dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||
| megdnn_assert( | |||
| A_dtype.enumv() == B_dtype.enumv() && (A_dtype.enumv() == DTypeEnum::Int8 && | |||
| C_dtype.enumv() == DTypeEnum::Int16), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| @@ -79,15 +75,15 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
| size_t n = 0; | |||
| const dt_int8* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, | |||
| is_first_k); | |||
| matmul_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, is_first_k); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K8; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_8x8x8::kern_8x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| matmul_8x8x8::kern_8x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(N - n, 4)); | |||
| output += 4; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -99,16 +95,17 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
| const dt_int8* cur_packB = packB; | |||
| size_t n = 0; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4)); | |||
| matmul_8x8x8::kern_4x8( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4)); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K8; | |||
| } | |||
| for (; n < N; n += 4) { | |||
| matmul_8x8x8::kern_4x4(packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4), | |||
| std::min<size_t>(N - n, 4)); | |||
| matmul_8x8x8::kern_4x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, 4), std::min<size_t>(N - n, 4)); | |||
| output += 4; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -119,39 +116,33 @@ void gemm_s8x8x16_8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
| // ===========================gemm_s8x8x16_4x4================================== | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_4x4); | |||
| void gemm_s8x8x16_4x4::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, | |||
| int ymax, int k0, int kmax, | |||
| bool transpose) const { | |||
| void gemm_s8x8x16_4x4::pack_A( | |||
| dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
| bool transpose) const { | |||
| if (transpose) { | |||
| matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, y0, ymax, k0, kmax); | |||
| } else { | |||
| matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0, | |||
| kmax); | |||
| matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| } | |||
| void gemm_s8x8x16_4x4::pack_B(dt_int8* out, const dt_int8* in, int ldin, int x0, | |||
| int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| void gemm_s8x8x16_4x4::pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool transpose) const { | |||
| if (transpose) { | |||
| matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, | |||
| kmax); | |||
| matmul_4x4x16::gemm_s8x8x16_4x4_pack_A_n(out, in, ldin, x0, xmax, k0, kmax); | |||
| } else { | |||
| matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, | |||
| kmax); | |||
| matmul_4x4x16::gemm_s8x8x16_4x4_pack_B_n(out, in, ldin, x0, xmax, k0, kmax); | |||
| } | |||
| } | |||
| void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||
| size_t M, size_t N, size_t K, dt_int16* C, | |||
| size_t LDC, bool is_first_k, const dt_int16*, | |||
| dt_int16*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| (A_dtype.enumv() == DTypeEnum::Int8 && | |||
| C_dtype.enumv() == DTypeEnum::Int16), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), | |||
| C_dtype.name()); | |||
| void gemm_s8x8x16_4x4::kern( | |||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
| dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||
| megdnn_assert( | |||
| A_dtype.enumv() == B_dtype.enumv() && (A_dtype.enumv() == DTypeEnum::Int8 && | |||
| C_dtype.enumv() == DTypeEnum::Int16), | |||
| "A: %s B: %s C: %s", A_dtype.name(), B_dtype.name(), C_dtype.name()); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| MEGDNN_MARK_USED_VAR(C_dtype); | |||
| @@ -169,16 +160,17 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||
| size_t n = 0; | |||
| const dt_int8* cur_packB = packB; | |||
| for (; n + B_INTERLEAVE - 1 < N; n += B_INTERLEAVE) { | |||
| matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, A_INTERLEAVE, B_INTERLEAVE); | |||
| matmul_4x4x16::kern_4x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, | |||
| B_INTERLEAVE); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K4; | |||
| } | |||
| for (; n < N; n += B_INTERLEAVE) { | |||
| matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, A_INTERLEAVE, | |||
| std::min<size_t>(N - n, B_INTERLEAVE)); | |||
| matmul_4x4x16::kern_4x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, A_INTERLEAVE, | |||
| std::min<size_t>(N - n, B_INTERLEAVE)); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -191,10 +183,10 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||
| size_t n = 0; | |||
| const dt_int8* cur_packB = packB; | |||
| for (; n < N; n += B_INTERLEAVE) { | |||
| matmul_4x4x16::kern_4x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, | |||
| std::min<size_t>(M - m, A_INTERLEAVE), | |||
| std::min<size_t>(N - n, B_INTERLEAVE)); | |||
| matmul_4x4x16::kern_4x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, | |||
| std::min<size_t>(M - m, A_INTERLEAVE), | |||
| std::min<size_t>(N - n, B_INTERLEAVE)); | |||
| output += B_INTERLEAVE; | |||
| cur_packB += K4; | |||
| } | |||
| @@ -205,28 +197,26 @@ void gemm_s8x8x16_4x4::kern(const dt_int8* packA, const dt_int8* packB, | |||
| // ===========================gemm_s8x8x16_mk4_16x12================================== | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_16x12_a53); | |||
| void gemm_s8x8x16_mk4_16x12_a53::pack_A(dt_int16* out, const dt_int8* in, | |||
| int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool) const { | |||
| matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A(out, in, ldin, y0, | |||
| ymax, k0, kmax); | |||
| void gemm_s8x8x16_mk4_16x12_a53::pack_A( | |||
| dt_int16* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
| bool) const { | |||
| matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_A( | |||
| out, in, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| void gemm_s8x8x16_mk4_16x12_a53::pack_B(dt_int8* out, const dt_int8* in, | |||
| int ldin, int x0, int xmax, int k0, | |||
| int kmax, bool) const { | |||
| matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B(out, in, ldin, x0, | |||
| xmax, k0, kmax); | |||
| void gemm_s8x8x16_mk4_16x12_a53::pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool) const { | |||
| matmul_mk4_16x12x4_a53::gemm_s8x8x16_mk4_16x12_pack_B( | |||
| out, in, ldin, x0, xmax, k0, kmax); | |||
| } | |||
| void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||
| const dt_int8* packB, size_t M, size_t N, | |||
| size_t K, dt_int16* C, size_t LDC, | |||
| bool is_first_k, const dt_int16*, | |||
| dt_int16*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| C_dtype.enumv() == DTypeEnum::Int16 && | |||
| A_dtype.enumv() == DTypeEnum::Int8); | |||
| void gemm_s8x8x16_mk4_16x12_a53::kern( | |||
| const dt_int16* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
| dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||
| megdnn_assert( | |||
| A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 && | |||
| A_dtype.enumv() == DTypeEnum::Int8); | |||
| megdnn_assert(is_first_k == true, "only impl is_first_k"); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| @@ -246,14 +236,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||
| size_t n_idx = 0; | |||
| const int8_t* cur_packB = packB; | |||
| for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
| matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k, pack_n); | |||
| matmul_mk4_16x12x4_a53::kern_16x12( | |||
| packA, cur_packB, K, output, LDC, is_first_k, pack_n); | |||
| output += pack_n * pack_size; | |||
| cur_packB += pack_n * K; | |||
| } | |||
| if (remain_n > 0) { | |||
| matmul_mk4_16x12x4_a53::kern_16x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k, remain_n); | |||
| matmul_mk4_16x12x4_a53::kern_16x12( | |||
| packA, cur_packB, K, output, LDC, is_first_k, remain_n); | |||
| output += remain_n * pack_size; | |||
| cur_packB += pack_n * K; | |||
| } | |||
| @@ -265,14 +255,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||
| size_t n_idx = 0; | |||
| const int8_t* cur_packB = packB; | |||
| for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
| matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k, pack_n); | |||
| matmul_mk4_16x12x4_a53::kern_8x12( | |||
| packA, cur_packB, K, output, LDC, is_first_k, pack_n); | |||
| output += pack_n * pack_size; | |||
| cur_packB += pack_n * K; | |||
| } | |||
| if (remain_n > 0) { | |||
| matmul_mk4_16x12x4_a53::kern_8x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k, remain_n); | |||
| matmul_mk4_16x12x4_a53::kern_8x12( | |||
| packA, cur_packB, K, output, LDC, is_first_k, remain_n); | |||
| output += remain_n * pack_size; | |||
| cur_packB += pack_n * K; | |||
| } | |||
| @@ -286,14 +276,14 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||
| size_t n_idx = 0; | |||
| const int8_t* cur_packB = packB; | |||
| for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
| matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k, pack_n); | |||
| matmul_mk4_16x12x4_a53::kern_4x12( | |||
| packA, cur_packB, K, output, LDC, is_first_k, pack_n); | |||
| output += pack_n * pack_size; | |||
| cur_packB += pack_n * K; | |||
| } | |||
| if (remain_n > 0) { | |||
| matmul_mk4_16x12x4_a53::kern_4x12(packA, cur_packB, K, output, LDC, | |||
| is_first_k, remain_n); | |||
| matmul_mk4_16x12x4_a53::kern_4x12( | |||
| packA, cur_packB, K, output, LDC, is_first_k, remain_n); | |||
| output += remain_n * pack_size; | |||
| cur_packB += pack_n * K; | |||
| } | |||
| @@ -303,27 +293,26 @@ void gemm_s8x8x16_mk4_16x12_a53::kern(const dt_int16* packA, | |||
| // ===========================gemm_s8x8x16_mk4_4x4_a72================================== | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_4x4_a72); | |||
| void gemm_s8x8x16_mk4_4x4_a72::pack_A(dt_int8* out, const dt_int8* in, int ldin, | |||
| int y0, int ymax, int k0, int kmax, | |||
| bool) const { | |||
| matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A(out, in, ldin, y0, ymax, | |||
| k0, kmax); | |||
| void gemm_s8x8x16_mk4_4x4_a72::pack_A( | |||
| dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
| bool) const { | |||
| matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_A( | |||
| out, in, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| void gemm_s8x8x16_mk4_4x4_a72::pack_B(dt_int8* out, const dt_int8* in, int ldin, | |||
| int x0, int xmax, int k0, int kmax, | |||
| bool) const { | |||
| matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B(out, in, ldin, x0, xmax, | |||
| k0, kmax); | |||
| void gemm_s8x8x16_mk4_4x4_a72::pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool) const { | |||
| matmul_mk4_4x4x8_a72::gemm_s8x8x16_mk4_4x4x8_pack_B( | |||
| out, in, ldin, x0, xmax, k0, kmax); | |||
| } | |||
| void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, | |||
| size_t M, size_t N, size_t K, dt_int16* C, | |||
| size_t LDC, bool is_first_k, | |||
| const dt_int16*, dt_int16*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| C_dtype.enumv() == DTypeEnum::Int16 && | |||
| A_dtype.enumv() == DTypeEnum::Int8); | |||
| void gemm_s8x8x16_mk4_4x4_a72::kern( | |||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
| dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||
| megdnn_assert( | |||
| A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 && | |||
| A_dtype.enumv() == DTypeEnum::Int8); | |||
| megdnn_assert(is_first_k == true, "only impl is_first_k"); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| @@ -343,14 +332,14 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, | |||
| const int8_t* cur_packB = packB; | |||
| for (size_t n_idx = 0; n_idx < nend; n_idx += pack_n) { | |||
| matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, pack_n); | |||
| matmul_mk4_4x4x8_a72::kern_4x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, pack_n); | |||
| output += pack_n * pack_size; | |||
| cur_packB += pack_n * packed_k; | |||
| } | |||
| if (remain_n > 0) { | |||
| matmul_mk4_4x4x8_a72::kern_4x4(packA, cur_packB, K, output, LDC, | |||
| is_first_k, remain_n); | |||
| matmul_mk4_4x4x8_a72::kern_4x4( | |||
| packA, cur_packB, K, output, LDC, is_first_k, remain_n); | |||
| output += remain_n * pack_size; | |||
| cur_packB += pack_n * packed_k; | |||
| } | |||
| @@ -361,27 +350,24 @@ void gemm_s8x8x16_mk4_4x4_a72::kern(const dt_int8* packA, const dt_int8* packB, | |||
| // ===========================gemm_s8x8x16_mk4_8x8x8================================== | |||
| MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8x8x16_mk4_8x8x8); | |||
| void gemm_s8x8x16_mk4_8x8x8::pack_A(dt_int8* out, const dt_int8* in, | |||
| int ldin, int y0, int ymax, int k0, | |||
| int kmax, bool) const { | |||
| matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0, | |||
| ymax, k0, kmax); | |||
| void gemm_s8x8x16_mk4_8x8x8::pack_A( | |||
| dt_int8* out, const dt_int8* in, int ldin, int y0, int ymax, int k0, int kmax, | |||
| bool) const { | |||
| matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_A(out, in, ldin, y0, ymax, k0, kmax); | |||
| } | |||
| void gemm_s8x8x16_mk4_8x8x8::pack_B(dt_int8* out, const dt_int8* in, | |||
| int ldin, int x0, int xmax, int k0, | |||
| int kmax, bool) const { | |||
| matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0, | |||
| xmax, k0, kmax); | |||
| void gemm_s8x8x16_mk4_8x8x8::pack_B( | |||
| dt_int8* out, const dt_int8* in, int ldin, int x0, int xmax, int k0, int kmax, | |||
| bool) const { | |||
| matmul_mk4_8x8x8::gemm_s8x8x16_mk4_8x8x8_pack_B(out, in, ldin, x0, xmax, k0, kmax); | |||
| } | |||
| void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
| size_t M, size_t N, size_t K, dt_int16* C, | |||
| size_t LDC, bool is_first_k, const dt_int16*, | |||
| dt_int16*) const { | |||
| megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && | |||
| C_dtype.enumv() == DTypeEnum::Int16 && | |||
| A_dtype.enumv() == DTypeEnum::Int8); | |||
| void gemm_s8x8x16_mk4_8x8x8::kern( | |||
| const dt_int8* packA, const dt_int8* packB, size_t M, size_t N, size_t K, | |||
| dt_int16* C, size_t LDC, bool is_first_k, const dt_int16*, dt_int16*) const { | |||
| megdnn_assert( | |||
| A_dtype.enumv() == B_dtype.enumv() && C_dtype.enumv() == DTypeEnum::Int16 && | |||
| A_dtype.enumv() == DTypeEnum::Int8); | |||
| megdnn_assert(is_first_k == true, "only impl is_first_k"); | |||
| MEGDNN_MARK_USED_VAR(A_dtype); | |||
| MEGDNN_MARK_USED_VAR(B_dtype); | |||
| @@ -402,14 +388,14 @@ void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
| size_t n_idx = 0; | |||
| const int8_t* cur_packB = packB; | |||
| for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
| matmul_mk4_8x8x8::kern_8x8(packA, cur_packB, K, output, LDC, | |||
| is_first_k, pack_m, pack_n); | |||
| matmul_mk4_8x8x8::kern_8x8( | |||
| packA, cur_packB, K, output, LDC, is_first_k, pack_m, pack_n); | |||
| output += pack_n * pack_size; | |||
| cur_packB += KSIZE8; | |||
| } | |||
| if (remain_n > 0) { | |||
| matmul_mk4_8x8x8::kern_8x8_remain(packA, cur_packB, K, output, LDC, | |||
| is_first_k, pack_m, remain_n); | |||
| matmul_mk4_8x8x8::kern_8x8_remain( | |||
| packA, cur_packB, K, output, LDC, is_first_k, pack_m, remain_n); | |||
| output += remain_n * pack_size; | |||
| cur_packB += KSIZE8; | |||
| } | |||
| @@ -421,14 +407,14 @@ void gemm_s8x8x16_mk4_8x8x8::kern(const dt_int8* packA, const dt_int8* packB, | |||
| size_t n_idx = 0; | |||
| const int8_t* cur_packB = packB; | |||
| for (; n_idx + pack_n <= N; n_idx += pack_n) { | |||
| matmul_mk4_8x8x8::kern_4x8(packA, cur_packB, K, output, LDC, | |||
| is_first_k, 4, pack_n); | |||
| matmul_mk4_8x8x8::kern_4x8( | |||
| packA, cur_packB, K, output, LDC, is_first_k, 4, pack_n); | |||
| output += pack_n * pack_size; | |||
| cur_packB += pack_n * K; | |||
| } | |||
| if (remain_n > 0) { | |||
| matmul_mk4_8x8x8::kern_4x8_remain(packA, cur_packB, K, output, LDC, | |||
| is_first_k, 4, remain_n); | |||
| matmul_mk4_8x8x8::kern_4x8_remain( | |||
| packA, cur_packB, K, output, LDC, is_first_k, 4, remain_n); | |||
| output += remain_n * pack_size; | |||
| cur_packB += pack_n * K; | |||
| } | |||