diff --git a/mindspore/lite/nnacl/pack.c b/mindspore/lite/nnacl/pack.c index 457dfc2603..a39c1f862d 100644 --- a/mindspore/lite/nnacl/pack.c +++ b/mindspore/lite/nnacl/pack.c @@ -14,10 +14,10 @@ * limitations under the License. */ -#include "nnacl/pack.h" -#include #include +#include #include "nnacl/int8/conv_int8.h" +#include "nnacl/pack.h" void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) { return PackNCHWToNHWCFp32(src, dst, 1, plane, channel); @@ -458,11 +458,35 @@ void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight } int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; int dst_oc_offset = dst_kernel_offset + o * ic8 * kernel_plane * C8NUM; - for (int i = 0; i < input_channel; i++) { - int c8_block_num = i / C8NUM; + int i = 0; + for (; i < (ic8 - 1); i += C8NUM) { + int src_ic_offset = src_oc_offset + i; + int dst_ic_offset = dst_oc_offset + i * kernel_plane; +#ifdef ENABLE_ARM64 + int8x8_t src_s8 = vld1_s8(origin_weight_data + src_ic_offset); + int16x8_t src_s16 = vmovl_s8(src_s8); + int16x4_t src1_s16 = vget_low_s16(src_s16); + int16x4_t src2_s16 = vget_high_s16(src_s16); + int32x4_t src1_s32 = vmovl_s16(src1_s16); + int32x4_t src2_s32 = vmovl_s16(src2_s16); + int32x4_t zp_s32 = vdupq_n_s32(zp); + int32x4_t dst1_s32 = vsubq_s32(src1_s32, zp_s32); + int32x4_t dst2_s32 = vsubq_s32(src2_s32, zp_s32); + int16x4_t dst1_s16 = vqmovn_s32(dst1_s32); + int16x4_t dst2_s16 = vqmovn_s32(dst2_s32); + vst1_s16(packed_weight_data + dst_ic_offset, dst1_s16); + vst1_s16(packed_weight_data + dst_ic_offset + 4, dst2_s16); +#else + for (int ci = 0; ci < C8NUM; ++ci) { + (packed_weight_data + dst_ic_offset + ci)[0] = (int16_t)((origin_weight_data + src_ic_offset + ci)[0] - zp); + } +#endif + } + dst_oc_offset += (ic8 - 1) * kernel_plane * C8NUM; + for (; i < input_channel; i++) { int c8_block_rem = i % C8NUM; int src_ic_offset = src_oc_offset + i; - int dst_ic_offset = dst_oc_offset + c8_block_num * kernel_plane * C8NUM + c8_block_rem; + int dst_ic_offset = dst_oc_offset + c8_block_rem; (packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - zp); } }