| @@ -14,10 +14,10 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "nnacl/pack.h" | |||||
| #include <string.h> | |||||
| #include <stdlib.h> | #include <stdlib.h> | ||||
| #include <string.h> | |||||
| #include "nnacl/int8/conv_int8.h" | #include "nnacl/int8/conv_int8.h" | ||||
| #include "nnacl/pack.h" | |||||
| void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) { | void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) { | ||||
| return PackNCHWToNHWCFp32(src, dst, 1, plane, channel); | return PackNCHWToNHWCFp32(src, dst, 1, plane, channel); | ||||
| @@ -458,11 +458,35 @@ void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight | |||||
| } | } | ||||
| int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; | int src_oc_offset = src_kernel_offset + o * kernel_plane * input_channel; | ||||
| int dst_oc_offset = dst_kernel_offset + o * ic8 * kernel_plane * C8NUM; | int dst_oc_offset = dst_kernel_offset + o * ic8 * kernel_plane * C8NUM; | ||||
| for (int i = 0; i < input_channel; i++) { | |||||
| int c8_block_num = i / C8NUM; | |||||
| int i = 0; | |||||
| for (; i < (ic8 - 1); i += C8NUM) { | |||||
| int src_ic_offset = src_oc_offset + i; | |||||
| int dst_ic_offset = dst_oc_offset + i * kernel_plane; | |||||
| #ifdef ENABLE_ARM64 | |||||
| int8x8_t src_s8 = vld1_s8(origin_weight_data + src_ic_offset); | |||||
| int16x8_t src_s16 = vmovl_s8(src_s8); | |||||
| int16x4_t src1_s16 = vget_low_s16(src_s16); | |||||
| int16x4_t src2_s16 = vget_high_s16(src_s16); | |||||
| int32x4_t src1_s32 = vmovl_s16(src1_s16); | |||||
| int32x4_t src2_s32 = vmovl_s16(src2_s16); | |||||
| int32x4_t zp_s32 = vdupq_n_s32(zp); | |||||
| int32x4_t dst1_s32 = vsubq_s32(src1_s32, zp_s32); | |||||
| int32x4_t dst2_s32 = vsubq_s32(src2_s32, zp_s32); | |||||
| int16x4_t dst1_s16 = vqmovn_s32(dst1_s32); | |||||
| int16x4_t dst2_s16 = vqmovn_s32(dst2_s32); | |||||
| vst1_s16(packed_weight_data + dst_ic_offset, dst1_s16); | |||||
| vst1_s16(packed_weight_data + dst_ic_offset + 4, dst2_s16); | |||||
| #else | |||||
| for (int ci = 0; ci < C8NUM; ++ci) { | |||||
| (packed_weight_data + dst_ic_offset + ci)[0] = (int16_t)((origin_weight_data + src_ic_offset + ci)[0] - zp); | |||||
| } | |||||
| #endif | |||||
| } | |||||
| dst_oc_offset += (ic8 - 1) * kernel_plane * C8NUM; | |||||
| for (; i < input_channel; i++) { | |||||
| int c8_block_rem = i % C8NUM; | int c8_block_rem = i % C8NUM; | ||||
| int src_ic_offset = src_oc_offset + i; | int src_ic_offset = src_oc_offset + i; | ||||
| int dst_ic_offset = dst_oc_offset + c8_block_num * kernel_plane * C8NUM + c8_block_rem; | |||||
| int dst_ic_offset = dst_oc_offset + c8_block_rem; | |||||
| (packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - zp); | (packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - zp); | ||||
| } | } | ||||
| } | } | ||||