diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.cc index 8bd1e2f023..e02f7710ae 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/conv_fp16.cc @@ -215,7 +215,7 @@ void Conv3x3Fp16(float16_t *input_data, float16_t *transed_weight, const float16 for (int c = 0; c < output_channel; c++) { int oc8_block = c / C8NUM; int oc8_res = c % C8NUM; - int src_offset = oc8_block * C8NUM * out_w_block * out_h_block * tile_num + + int src_offset = oc8_block * C8NUM * out_w_block * out_h_block * C4NUM * C4NUM + C8NUM * (h * out_w_block * output_unit + w) + oc8_res; int dst_offset = (h * output_w + w) * output_channel + c; (output_data + dst_offset)[0] = (tmp_out + src_offset)[0]; diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.cc index e585509825..941d3bc705 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/winograd_transform_fp16.cc @@ -508,25 +508,25 @@ void Conv3x3Fp16OutputUnit(const float16_t *gemm_out, const float16_t *bias_data void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, const float16_t *bias_data, int start_index, int real_cal_num, int out_w_block, ConvParameter *conv_param) { int output_channel = conv_param->output_channel_; - int output_w = conv_param->output_w_; int output_h = conv_param->output_h_; + int out_h_block = UP_DIV(output_h, C4NUM); int oc8 = UP_DIV(output_channel, C8NUM); -// todo outputw --> out_w_block * out_unit + for (int i = 0; i < real_cal_num; i++) { int out_w_index = (start_index + i) % out_w_block; int out_h_index = (start_index + i) / out_w_block; int src_tile_offset = i * oc8 * C8NUM * 36; - int dst_tile_offset = 8 * (out_w_index * 4 + out_h_index * 4 * output_w); + int dst_tile_offset = C8NUM * (out_w_index * C4NUM + out_h_index * C4NUM * out_w_block * C4NUM); for (int j = 0; j < oc8; j++) { int src_oc8_offset = src_tile_offset + j * 36 * C8NUM; - int dst_oc8_offset = dst_tile_offset + j * C8NUM * output_h * output_w; + int dst_oc8_offset = dst_tile_offset + j * C8NUM * out_h_block * out_w_block * C4NUM * C4NUM; const float16_t *src_ptr = gemm_out + src_oc8_offset; const float16_t *bias_ptr = bias_data + j * C8NUM; float16_t *dst_ptr = out_data + dst_oc8_offset; // output transform - Conv3x3Fp16OutputUnit(src_ptr, bias_ptr, dst_ptr, output_w); + Conv3x3Fp16OutputUnit(src_ptr, bias_ptr, dst_ptr, out_w_block * C4NUM); } } }