Browse Source

fix(dnn/x86): fix x86 matrix mul usable() ignoring the format

GitOrigin-RevId: 40fe508aca
tags/v1.0.0-rc1
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
40e79e9dab
1 changed file with 20 additions and 10 deletions
  1. +20
    -10
      dnn/src/x86/matrix_mul/algos.cpp

+ 20
- 10
dnn/src/x86/matrix_mul/algos.cpp View File

@@ -10,13 +10,13 @@
* implied.
*/

#include "src/x86/matrix_mul/algos.h"
#include "midout.h"
#include "src/common/utils.h"
#include "src/fallback/matrix_mul/gemm_impl.h"
#include "src/x86/matrix_mul/algos.h"
#include "src/x86/matrix_mul/f32/strategy.h"
#include "src/x86/matrix_mul/int8/strategy.h"

#include "src/x86/matrix_mul/f32/strategy.h"
#include "midout.h"

MIDOUT_DECL(megdnn_x86_matmul_kern)
MIDOUT_DECL(megdnn_x86_matmul_kern_mk8_8x8)
@@ -170,6 +170,7 @@ bool MatrixMulImpl::AlgoInt8x8x32Vnni::usable(
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
preferred(kern_size_param) && is_supported(SIMDType::VNNI);
}

@@ -230,6 +231,7 @@ bool MatrixMulImpl::AlgoInt8x8x32Mkldnn::usable(
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::VNNI) && preferred(kern_size_param);
}

@@ -365,8 +367,10 @@ bool MatrixMulImpl::AlgoInt8x8x16AVX2::usable(
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16));
bool is_mode_ok =
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::AVX2);
bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok;

return is_param_ok;
}
bool MatrixMulImpl::AlgoInt8x8x16AVX2::preferred(const KernSizeParam&) const {
@@ -440,6 +444,7 @@ bool MatrixMulImpl::AlgoInt8x8x16SSE::usable(
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS16));
bool is_mode_ok =
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::SSE4_1);
bool is_param_ok = is_ab_same && is_type_ok && is_mode_ok;
return is_param_ok;
@@ -478,13 +483,16 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern(
}
/**
 * \brief Report whether this algorithm accepts the given matmul problem.
 *
 * The problem is accepted only when all of the following hold:
 *  - A and B have the same dtype enum;
 *  - (A, C) is (Int8, Int32) or (QuantizedS8, QuantizedS32);
 *  - the compute mode is Param::ComputeMode::DEFAULT;
 *  - the format is Param::Format::DEFAULT;
 *  - is_supported(SIMDType::AVX2) is true.
 *
 * \param kern_size_param description of the matmul problem to check
 * \return true if every condition above is satisfied, false otherwise
 */
bool MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::usable(
        const KernSizeParam& kern_size_param) const {
    // The format check must take part in the decision: there must be no
    // earlier return that skips it, otherwise non-DEFAULT formats would be
    // accepted by this algorithm.
    bool is_param_ok =
            kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
            ((kern_size_param.A_type.enumv() == DTypeEnum::Int8 &&
              kern_size_param.C_type.enumv() == DTypeEnum::Int32) ||
             (kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
              kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
            kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
            kern_size_param.format == Param::Format::DEFAULT &&
            is_supported(SIMDType::AVX2);
    return is_param_ok;
}
size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace(
const KernSizeParam& kern_param) const {
@@ -522,6 +530,7 @@ bool MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::usable(
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::AVX2);
}
size_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_workspace(
@@ -562,6 +571,7 @@ bool MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::usable(
(kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8 &&
kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32)) &&
kern_size_param.compute_mode == Param::ComputeMode::DEFAULT &&
kern_size_param.format == Param::Format::DEFAULT &&
is_supported(SIMDType::SSE4_1);
}
size_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_workspace(


Loading…
Cancel
Save