|
|
|
@@ -10,13 +10,13 @@ |
|
|
|
* implied. |
|
|
|
*/ |
|
|
|
|
|
|
|
#include "src/x86/matrix_mul/algos.h" |
|
|
|
#include "midout.h" |
|
|
|
#include "src/common/utils.h" |
|
|
|
#include "src/fallback/matrix_mul/gemm_impl.h" |
|
|
|
#include "src/x86/matrix_mul/algos.h" |
|
|
|
#include "src/x86/matrix_mul/f32/strategy.h" |
|
|
|
#include "src/x86/matrix_mul/int8/strategy.h" |
|
|
|
|
|
|
|
#include "src/x86/matrix_mul/f32/strategy.h" |
|
|
|
#include "midout.h" |
|
|
|
|
|
|
|
MIDOUT_DECL(megdnn_x86_matmul_kern) |
|
|
|
MIDOUT_DECL(megdnn_x86_matmul_kern_mk8_8x8) |
|
|
|
@@ -170,6 +170,7 @@ bool MatrixMulImpl::AlgoInt8x8x32Vnni::usable( |
|
|
|
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && |
|
|
|
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && |
|
|
|
kern_size_param.format == Param::Format::DEFAULT && |
|
|
|
preferred(kern_size_param) && is_supported(SIMDType::VNNI); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -230,6 +231,7 @@ bool MatrixMulImpl::AlgoInt8x8x32Mkldnn::usable( |
|
|
|
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && |
|
|
|
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && |
|
|
|
kern_size_param.format == Param::Format::DEFAULT && |
|
|
|
is_supported(SIMDType::VNNI) && preferred(kern_size_param); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -365,8 +367,10 @@ bool MatrixMulImpl::AlgoInt8x8x16AVX2::usable( |
|
|
|
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16)); |
|
|
|
bool is_mode_ok = |
|
|
|
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && |
|
|
|
kern_size_param.format == Param::Format::DEFAULT && |
|
|
|
is_supported(SIMDType::AVX2); |
|
|
|
bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok; |
|
|
|
|
|
|
|
return is_param_ok; |
|
|
|
} |
|
|
|
bool MatrixMulImpl::AlgoInt8x8x16AVX2::preferred(const KernSizeParam&) const { |
|
|
|
@@ -440,6 +444,7 @@ bool MatrixMulImpl::AlgoInt8x8x16SSE::usable( |
|
|
|
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16)); |
|
|
|
bool is_mode_ok = |
|
|
|
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && |
|
|
|
kern_size_param.format == Param::Format::DEFAULT && |
|
|
|
is_supported(SIMDType::SSE4_1); |
|
|
|
bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok; |
|
|
|
return is_param_ok; |
|
|
|
@@ -478,13 +483,16 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern( |
|
|
|
} |
|
|
|
bool MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::usable( |
|
|
|
const KernSizeParam& kern_size_param) const { |
|
|
|
return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && |
|
|
|
((kern_size_param.A_type.enumv() == DTypeEnum::Int8 && |
|
|
|
kern_size_param.C_type.enumv() == DTypeEnum::Int32) || |
|
|
|
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && |
|
|
|
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && |
|
|
|
is_supported(SIMDType::AVX2); |
|
|
|
bool is_param_ok = |
|
|
|
kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && |
|
|
|
((kern_size_param.A_type.enumv() == DTypeEnum::Int8 && |
|
|
|
kern_size_param.C_type.enumv() == DTypeEnum::Int32) || |
|
|
|
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && |
|
|
|
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && |
|
|
|
kern_size_param.format == Param::Format::DEFAULT && |
|
|
|
is_supported(SIMDType::AVX2); |
|
|
|
return is_param_ok; |
|
|
|
} |
|
|
|
size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace( |
|
|
|
const KernSizeParam& kern_param) const { |
|
|
|
@@ -522,6 +530,7 @@ bool MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::usable( |
|
|
|
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && |
|
|
|
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && |
|
|
|
kern_size_param.format == Param::Format::DEFAULT && |
|
|
|
is_supported(SIMDType::AVX2); |
|
|
|
} |
|
|
|
size_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_workspace( |
|
|
|
@@ -562,6 +571,7 @@ bool MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::usable( |
|
|
|
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 && |
|
|
|
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) && |
|
|
|
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && |
|
|
|
kern_size_param.format == Param::Format::DEFAULT && |
|
|
|
is_supported(SIMDType::SSE4_1); |
|
|
|
} |
|
|
|
size_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_workspace( |
|
|
|
|