|
|
|
@@ -400,9 +400,9 @@ kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::tensor::Ten |
|
|
|
kernel::LiteKernel *kernel; |
|
|
|
auto filter_quant_size = inputs[kWeightIndex]->GetQuantParams().size(); |
|
|
|
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { |
|
|
|
kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); |
|
|
|
kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); |
|
|
|
} else if (kernel_h == 1 && kernel_w == 1 && filter_quant_size == 1) { |
|
|
|
kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); |
|
|
|
kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); |
|
|
|
} else { |
|
|
|
kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); |
|
|
|
} |
|
|
|
|