Browse Source

Fix vqrshlq_s32 rounding for negative numbers

tags/v1.1.0
lixian 5 years ago
parent
commit
74a1f9b87a
1 changed file with 12 additions and 0 deletions
  1. mindspore/lite/nnacl/winograd_transform.c (+12, -0)

mindspore/lite/nnacl/winograd_transform.c (+12, -0) View File

@@ -685,6 +685,9 @@ void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, in


 d00 = vqshlq_s32(d00, ls);
 d00 = vqrdmulhq_s32(d00, out_multiplier);
+int32x4_t carry = vandq_s32(d00, rs);
+carry = vshrq_n_s32(carry, 31);
+d00 = vqaddq_s32(d00, carry);
 d00 = vqrshlq_s32(d00, rs);
 d00 = vaddq_s32(d00, out_zp);
 d00 = vmaxq_s32(d00, output_min);
@@ -692,6 +695,9 @@ void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, in


 d01 = vqshlq_s32(d01, ls);
 d01 = vqrdmulhq_s32(d01, out_multiplier);
+carry = vandq_s32(d01, rs);
+carry = vshrq_n_s32(carry, 31);
+d01 = vqaddq_s32(d01, carry);
 d01 = vqrshlq_s32(d01, rs);
 d01 = vaddq_s32(d01, out_zp);
 d01 = vmaxq_s32(d01, output_min);
@@ -699,6 +705,9 @@ void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, in


 d10 = vqshlq_s32(d10, ls);
 d10 = vqrdmulhq_s32(d10, out_multiplier);
+carry = vandq_s32(d10, rs);
+carry = vshrq_n_s32(carry, 31);
+d10 = vqaddq_s32(d10, carry);
 d10 = vqrshlq_s32(d10, rs);
 d10 = vaddq_s32(d10, out_zp);
 d10 = vmaxq_s32(d10, output_min);
@@ -706,6 +715,9 @@ void Conv3x3Int8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, in


 d11 = vqshlq_s32(d11, ls);
 d11 = vqrdmulhq_s32(d11, out_multiplier);
+carry = vandq_s32(d11, rs);
+carry = vshrq_n_s32(carry, 31);
+d11 = vqaddq_s32(d11, carry);
 d11 = vqrshlq_s32(d11, rs);
 d11 = vaddq_s32(d11, out_zp);
 d11 = vmaxq_s32(d11, output_min);


Loading…
Cancel
Save