Merge pull request !7122 from liuzhongkai/winograd
@@ -0,0 +1,206 @@
#ifdef __aarch64__
    .text
    .align 5
    .global MatrixMultiplyWinogradFp16
#ifndef __APPLE__
    .type MatrixMultiplyWinogradFp16, %function
#endif

// MatrixMultiplyWinogradFp16(float16_t *matrix_a, float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n, int in_channel)
// x0: matrix_a, x1: matrix_b, x2: matrix_c, x3: m, x4: k, x5: n, x6: in_channel
MatrixMultiplyWinogradFp16:
    // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
    // x19 ~ x29 should also be preserved,
    // whereas our coding style does not permit saving and restoring such an amount of registers
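    // of the callee-saved set, only v8 and x19 ~ x21 are clobbered below; x22 is stored only to
    // fill the second slot of the stp pair. spill area layout (48 bytes, ascending from sp):
    //   [sp, #0] v8    [sp, #16] x19, x20    [sp, #32] x21, x22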
    sub sp, sp, #48
    st1 {v8.8h}, [sp], #16
    stp x19, x20, [sp], #16
    stp x21, x22, [sp], #16

    mov x8, #2
    mul x10, x5, x8           // n * 2
    mov x17, x3               // m
    mul x13, x6, x8           // in_channel * 2
    mul x21, x13, x4          // in_channel * k * 2

LoopM:
    mov x15, x5               // n
    mov x14, x1               // mat_b
LoopN:
    mov x16, x0               // mat_a_m
    sub x18, x5, x15          // ni
    sub x19, x17, x3          // mi
    mul x18, x18, x17         // ni * m
    mov x11, x6               // in_channel
    add x18, x18, x19         // (ni * m) + mi
    mul x18, x18, x13         // x18 * channel_in * 2
    add x20, x2, x18          // dst + offset
    cmp x11, #32
    bge LoopC32
    cmp x11, #16
    bge LoopC16
    cmp x11, #8
    bge LoopC8
    cmp x11, #4
    bge LoopC4
    cmp x11, #1
    bge LoopC
    b EndLoopC
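    // the remaining input channels are consumed with progressively narrower tiles:
    // 32, 16, 8 and 4 fp16 lanes per k-loop, then a scalar tail of 1; each Write*
    // block re-dispatches on the channels left in x11.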
LoopC32:
    mov x12, x14
    mov x9, x4                // new_k
    dup v5.8h, wzr
    dup v6.8h, wzr
    dup v7.8h, wzr
    dup v8.8h, wzr
LoopK32:
    ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x16], x13
    ldr h4, [x12]
    add x12, x12, x10
    fmla v5.8h, v0.8h, v4.h[0]
    fmla v6.8h, v1.8h, v4.h[0]
    fmla v7.8h, v2.8h, v4.h[0]
    fmla v8.8h, v3.8h, v4.h[0]
    subs x9, x9, #1
    bne LoopK32
Write32:
    st1 {v5.8h}, [x20], #16
    st1 {v6.8h}, [x20], #16
    st1 {v7.8h}, [x20], #16
    st1 {v8.8h}, [x20], #16
    sub x16, x16, x21         // back x13 * k
    add x16, x16, #64         // add 64B
    subs x11, x11, #32
    beq EndLoopC
    cmp x11, #32
    bge LoopC32
    cmp x11, #16
    bge LoopC16
    cmp x11, #8
    bge LoopC8
    cmp x11, #4
    bge LoopC4
    cmp x11, #1
    bge LoopC

LoopC16:
    dup v5.8h, wzr
    dup v6.8h, wzr
    mov x9, x4                // new_k
    mov x12, x14
LoopK16:
    ld1 {v0.8h, v1.8h}, [x16], x13
    ldr h4, [x12]
    add x12, x12, x10
    fmla v5.8h, v0.8h, v4.h[0]
    fmla v6.8h, v1.8h, v4.h[0]
    subs x9, x9, #1
    bne LoopK16
Write16:
    st1 {v5.8h}, [x20], #16
    st1 {v6.8h}, [x20], #16
    sub x16, x16, x21         // back x13 * k
    add x16, x16, #32         // add 32B
    subs x11, x11, #16
    beq EndLoopC
    cmp x11, #16
    bge LoopC16
    cmp x11, #8
    bge LoopC8
    cmp x11, #4
    bge LoopC4
    cmp x11, #1
    bge LoopC

LoopC8:
    dup v5.8h, wzr
    mov x9, x4                // new_k
    mov x12, x14
LoopK8:
    ld1 {v0.8h}, [x16], x13
    ldr h4, [x12]
    add x12, x12, x10
    fmla v5.8h, v0.8h, v4.h[0]
    subs x9, x9, #1
    bne LoopK8
Write8:
    st1 {v5.8h}, [x20], #16
    sub x16, x16, x21         // ptr back x13 * k
    add x16, x16, #16         // add 16B
    subs x11, x11, #8
    beq EndLoopC
    cmp x11, #8
    bge LoopC8
    cmp x11, #4
    bge LoopC4
    cmp x11, #1
    bge LoopC

LoopC4:
    dup v5.4h, wzr
    mov x9, x4                // new_k
    mov x12, x14
LoopK4:
    ld1 {v0.4h}, [x16], x13
    ldr h4, [x12]
    add x12, x12, x10
    fmla v5.4h, v0.4h, v4.h[0]
    subs x9, x9, #1
    bne LoopK4
Write4:
    st1 {v5.4h}, [x20], #8
    sub x16, x16, x21         // ptr back x13 * k
    add x16, x16, #8          // add 8B
    subs x11, x11, #4
    beq EndLoopC
    cmp x11, #4
    bge LoopC4
    cmp x11, #1
    bge LoopC

LoopC:
    dup v5.8h, wzr
    mov x9, x4                // new_k
    mov x12, x14
LoopK:
    ldr h0, [x16]
    add x16, x16, x13
    ldr h4, [x12]
    add x12, x12, x10
    fmul h0, h0, h4
    fadd h5, h5, h0
    subs x9, x9, #1
    bne LoopK
Write:
    str h5, [x20], #2
    sub x16, x16, x21         // ptr back x13 * k
    add x16, x16, #2          // ptr add 2B
    subs x11, x11, #1
    beq EndLoopC
    b LoopC

EndLoopC:
    add x14, x14, #2
    subs x15, x15, #1
    beq EndLoopN
    b LoopN

EndLoopN:
    subs x3, x3, #1
    beq EndLoopM
    add x0, x0, x21
    b LoopM

EndLoopM:
    sub sp, sp, #48
    ld1 {v8.8h}, [sp], #16    // restore v8
    ldp x19, x20, [sp], #16
    ldp x21, x22, [sp], #16
    ret
#endif
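For reference only (not part of the patch), a scalar sketch of what the kernel above computes, assuming matrix_a is laid out row-major as [m][k][in_channel] and matrix_b as [k][n]; the helper name is hypothetical. The destination index (ni * m + mi) * in_channel mirrors the x18/x20 addressing: the ARM64 kernel writes its result with m and n swapped, which is what lets the ARM64 branch of the filter transform further below skip the explicit transposes.

#include <arm_neon.h>  // float16_t on aarch64

// hypothetical reference mirroring the addressing used by the assembly above
static void MatrixMultiplyWinogradFp16Sketch(const float16_t *matrix_a, const float16_t *matrix_b,
                                             float16_t *matrix_c, int m, int k, int n, int in_channel) {
  for (int mi = 0; mi < m; ++mi) {
    for (int ni = 0; ni < n; ++ni) {
      for (int ci = 0; ci < in_channel; ++ci) {
        float16_t acc = 0;
        for (int ki = 0; ki < k; ++ki) {
          acc += matrix_a[(mi * k + ki) * in_channel + ci] * matrix_b[ki * n + ni];
        }
        matrix_c[(ni * m + mi) * in_channel + ci] = acc;  // output laid out as [n][m][in_channel]
      }
    }
  }
}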
@@ -32,6 +32,24 @@ void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, fl
  }
}
#ifndef ENABLE_ARM64
void MatrixMultiplyWinogradFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k,
                                int n, int in_channel) {
  int cnt = 0;
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int y = 0; y < in_channel; ++y) {
        float16_t tmp = 0;
        for (int z = 0; z < k; ++z) {
          tmp += matrix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n];
        }
        matrix_c[cnt++] = tmp;
      }
    }
  }
}
#endif
void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matrix_b, float16x8_t *matrix_c,
                           const float16_t *bias, int m, int k, int n) {
  if (bias == NULL) {
@@ -26,6 +26,8 @@ void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, fl
void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matrix_b, float16x8_t *matrix_c,
                           const float16_t *bias, int m, int k, int n);
void MatrixMultiplyWinogradFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k,
                                int n, int in_channel);
#ifdef __cplusplus
}
#endif
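A usage sketch with hypothetical shapes (F(2x2, 3x3), so kernel_unit = 3 and input_unit = 4): one half-step of the filter transform for a single output channel. The wrapper name is invented; it assumes the header above, which declares MatrixMultiplyWinogradFp16, is included.

#include <arm_neon.h>
// assumes the declaration of MatrixMultiplyWinogradFp16 from the header hunk above is visible

void FilterTransformHalfStepSketch(const float16_t *g,   /* [3][3][channel_in] slice of the ohwi weights */
                                   const float16_t *gt,  /* G^T stored row-major as [3][4], already fp16 */
                                   float16_t *tmp,       /* [4][3][channel_in] on ARM64, [3][4][channel_in] otherwise */
                                   int channel_in) {
  // m = k = kernel_unit = 3, n = input_unit = 4
  MatrixMultiplyWinogradFp16(g, gt, tmp, 3, 3, 4, channel_in);
}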
@@ -65,6 +65,14 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1
  } // tile num loop
}
void PackHWCToWHCFp16(const float16_t *src, float16_t *dst, int height, int width, int channel) {
  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; ++j) {
      memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float16_t));
    }
  }
}
void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) {
  // origin weight format : ohwi
  int input_channel = conv_param->input_channel_;
@@ -31,6 +31,8 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1
void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param);
void PackHWCToWHCFp16(const float16_t *src, float16_t *dst, int height, int width, int channel);
void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param);
void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel);
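The new PackHWCToWHCFp16 helper swaps the two leading axes of an H x W x C tensor, copying whole channel vectors: dst[w][h][c] = src[h][w][c]. A toy illustration with hypothetical values (height = 2, width = 3, channel = 1):

#include <arm_neon.h>
// assumes the pack header above, which declares PackHWCToWHCFp16, is included

void PackHWCToWHCExample(void) {
  float16_t src[6] = {0, 1, 2, 3, 4, 5};  // [h][w][c] with h = 2, w = 3, c = 1
  float16_t dst[6];                       // becomes [w][h][c]
  PackHWCToWHCFp16(src, dst, 2, 3, 1);    // dst == {0, 3, 1, 4, 2, 5}
}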
@@ -36,91 +36,106 @@ using mindspore::schema::PrimitiveType_Conv2D;
namespace mindspore::kernel {
int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_t *weight_data, float *matrix_g,
                                                                  float *matrix_gt, int oc_block) {
  if (oc_block == 0) {
    MS_LOG(ERROR) << "Divide by zero";
    return RET_ERROR;
  }
  // original weight format : ohwi
  auto channel_in = conv_param_->input_channel_;
  auto channel_out = conv_param_->output_channel_;
  int input_unit_square = input_unit_ * input_unit_;
  int oc_block_num = UP_DIV(channel_out, oc_block);
  int block_stride = channel_in * oc_block;
  int block_num_stride = block_stride * oc_block_num;
  auto matrix_g_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t)));
  if (matrix_g_data_fp16 == nullptr) {
    MS_LOG(ERROR) << "malloc matrix_g_data_fp16 failed.";
    return RET_ERROR;
  }
  auto matrix_gt_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t)));
  if (matrix_gt_data_fp16 == nullptr) {
    free(matrix_g_data_fp16);
    MS_LOG(ERROR) << "malloc matrix_gt_data_fp16 failed.";
    return RET_ERROR;
  }
  Float32ToFloat16(matrix_g, matrix_g_data_fp16, input_unit_ * kernel_unit_);
  Float32ToFloat16(matrix_gt, matrix_gt_data_fp16, input_unit_ * kernel_unit_);
  // trans_filter = G*g*GT (g represents weight_data)
  // separate into two steps ===> tmp = G*g ===> out = tmp * GT
  auto tmp_weight_data = reinterpret_cast<float16_t *>(malloc(kernel_unit_ * kernel_unit_ * sizeof(float16_t)));
  if (tmp_weight_data == nullptr) {
    free(matrix_g_data_fp16);
    free(matrix_gt_data_fp16);
    MS_LOG(ERROR) << "malloc tmp_weight_data failed.";
    return RET_ERROR;
  }
  auto tmp_data = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t)));
  // trans_filter = G*g*GT (g represents weight_data) = [(g * (G)T)T * (G)T]T
  // separate into two steps ===> tmp = (g * (G)T)T ===> out = [tmp * (G)T]T
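  // (the identity holds because ((g * (G)T)T * (G)T)T = (G * (g)T * (G)T)T = G * g * (G)T,
  //  so both multiplies can be carried out with (G)T alone)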
  auto tmp_data = reinterpret_cast<float16_t *>(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float16_t)));
  if (tmp_data == nullptr) {
    free(tmp_weight_data);
    free(matrix_g_data_fp16);
    free(matrix_gt_data_fp16);
    MS_LOG(ERROR) << "malloc tmp_data failed.";
    return RET_ERROR;
  }
  auto trans_out_data = reinterpret_cast<float16_t *>(malloc(input_unit_ * input_unit_ * sizeof(float16_t)));
  auto trans_out_data =
    reinterpret_cast<float16_t *>(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float16_t)));
  if (trans_out_data == nullptr) {
    free(tmp_data);
    free(tmp_weight_data);
    free(matrix_g_data_fp16);
    free(matrix_gt_data_fp16);
    MS_LOG(ERROR) << "malloc trans_out_data failed.";
    return RET_ERROR;
  }
  if (oc_block == 0) {
    MS_LOG(ERROR) << "Divide by zero";
    free(tmp_weight_data);
#ifndef ENABLE_ARM64
  auto tmp_data1 = reinterpret_cast<float16_t *>(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float16_t)));
  if (tmp_data1 == nullptr) {
    free(tmp_data);
    free(matrix_gt_data_fp16);
    free(trans_out_data);
    free(matrix_g_data_fp16);
    MS_LOG(ERROR) << "malloc tmp_data1 failed.";
    return RET_ERROR;
  }
  auto trans_out_data1 =
    reinterpret_cast<float16_t *>(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float16_t)));
  if (trans_out_data1 == nullptr) {
    free(tmp_data);
    free(tmp_data1);
    free(matrix_gt_data_fp16);
    free(trans_out_data);
    MS_LOG(ERROR) << "malloc trans_out_data1 failed.";
    return RET_ERROR;
  }
  int stride1 = channel_in * oc_block;
#endif
  int input_oz_offset = kernel_unit_ * kernel_unit_ * channel_in;
  for (int i = 0; i < channel_out; i++) {
    int out_c_block = i / oc_block;
    int out_c_res = i % oc_block;
    int input_oz_offset = i * kernel_unit_ * kernel_unit_ * channel_in;
    int output_oz_offset = out_c_block * stride1 + out_c_res;
    for (int j = 0; j < channel_in; j++) {
      int input_iz_offset = input_oz_offset + j;
      int output_iz_offset = output_oz_offset + j * oc_block;
      for (int k = 0; k < kernel_unit_ * kernel_unit_; k++) {
        int input_xy_offset = input_iz_offset + k * channel_in;
        tmp_weight_data[k] = *(weight_data + input_xy_offset);
      }
      // now we only support row-major matrix-multiply
      // tmp = G * g
      MatrixMultiplyFp16(matrix_g_data_fp16, tmp_weight_data, tmp_data, input_unit_, kernel_unit_, kernel_unit_);
      // out = tmp * GT
      MatrixMultiplyFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit_, kernel_unit_, input_unit_);
      for (int z = 0; z < input_unit_square; z++) {
        int output_xy_offset = output_iz_offset + z * oc_block_num * stride1;
        trans_weight_[output_xy_offset] = trans_out_data[z];
    int output_oz_offset = out_c_block * block_stride + out_c_res;
#ifndef ENABLE_ARM64
    // tmp_data = g * GT
    MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit_,
                               kernel_unit_, input_unit_, channel_in);
    // tmp_data1 = (tmp_data)T
    PackHWCToWHCFp16(tmp_data, tmp_data1, kernel_unit_, input_unit_, channel_in);
    // trans_out_data1 = tmp * GT
    MatrixMultiplyWinogradFp16(tmp_data1, matrix_gt_data_fp16, trans_out_data1, input_unit_, kernel_unit_, input_unit_,
                               channel_in);
    // trans_out_data = (trans_out_data1)T
    PackHWCToWHCFp16(trans_out_data1, trans_out_data, input_unit_, input_unit_, channel_in);
#else
    // tmp = (g * GT)T
    MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit_,
                               kernel_unit_, input_unit_, channel_in);
    // trans = (tmp * GT)T
    MatrixMultiplyWinogradFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit_, kernel_unit_, input_unit_,
                               channel_in);
#endif
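    // trans_out_data now holds one [input_unit][input_unit][channel_in] tile; scatter it into
    // trans_weight_, laid out as [input_unit * input_unit][oc_block_num][channel_in][oc_block],
    // advancing output_oz_offset by block_num_stride for every (j, k) position.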
    int in_offset = 0;
    for (int j = 0; j < input_unit_; ++j) {
      for (int k = 0; k < input_unit_; ++k) {
        for (int c = 0; c < channel_in; ++c) {
          *(trans_weight_ + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c];
        }
        in_offset += channel_in;
        output_oz_offset += block_num_stride;
      }
    }
  }
  free(tmp_weight_data);
#ifndef ENABLE_ARM64
  free(tmp_data1);
  free(trans_out_data1);
#endif
  free(tmp_data);
  free(trans_out_data);
  free(matrix_g_data_fp16);
  free(matrix_gt_data_fp16);
  return RET_OK;
}