@@ -251,12 +251,6 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
   }
 }
-// fp32 conv1x1 strassen matmul
-int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output_data, float *tmp_ptr,
-                StrassenMatMulParameter matmul_param) {
-  return StrassenMatmul(input_data, weight_data, output_data, &matmul_param, FP32_STRASSEN_MAX_RECURSION, 0, tmp_ptr);
-}
 // fp32 conv winograd
 void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
                       int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func,
@@ -24,7 +24,6 @@
 #include "nnacl/op_base.h"
 #include "nnacl/common_func.h"
 #include "nnacl/conv_parameter.h"
-#include "nnacl/fp32/strassen_matmul.h"
 #include "nnacl/winograd_utils.h"
 #include "nnacl/fp32/conv_depthwise.h"
@@ -52,10 +51,6 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
               float *tmp_out_block, float *output_data, int task_id, ConvParameter *conv_param,
               GEMM_FUNC_FP32 gemm_func);
-// fp32 conv1x1 strassen matmul
-int Conv1x1Fp32(const float *input_data, const float *weight_data, float *output_data, float *tmp_ptr,
-                StrassenMatMulParameter matmul_param);
 // fp32 convolution winograd
 void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, TmpBufferAddress *buffer_list,
                       int task_id, ConvParameter *conv_param, InputTransformUnitFunc input_trans_func,
@@ -33,18 +33,18 @@ void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, in
   return;
 }
-int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *dst, int output_channel,
-                       ConvParameter *conv_param) {
-  /* row8x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
+int DeConvPostFp32C12x8(const float *src, float *tmp, const float *bias, float *dst, int output_channel,
+                        ConvParameter *conv_param) {
+  /* row12x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
   size_t input_plane = conv_param->input_w_ * conv_param->input_h_;
   size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
   size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
   int oc8 = UP_ROUND(output_channel, C8NUM);
-  int in_plane8 = UP_ROUND(input_plane, C8NUM);
+  int in_plane12 = UP_ROUND(input_plane, C12NUM);
   int src_iw_stride = C8NUM;
   int src_ih_stride = conv_param->input_w_ * C8NUM;
-  int src_kw_stride = in_plane8 * C8NUM;
-  int src_kh_stride = in_plane8 * conv_param->kernel_w_ * C8NUM;
+  int src_kw_stride = in_plane12 * C8NUM;
+  int src_kh_stride = in_plane12 * conv_param->kernel_w_ * C8NUM;
   int dst_oh_stride = conv_param->output_w_ * C8NUM;
   int dst_ow_stride = C8NUM;
   int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C8NUM;
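Editorial note: the stride table above encodes how one 8-channel block of the row12x8-major GEMM output is walked during the deconv post step. As a reading aid, here is a minimal standalone sketch of the source offset those strides imply (hypothetical helper, not part of the diff; `8` stands in for `C8NUM` and `in_plane12` is the C12NUM-rounded input plane, as above):

```c
#include <stddef.h>

/* Hypothetical helper mirroring the four strides above: offset of input pixel
 * (ih, iw) for kernel tap (kh, kw) and channel lane ci inside one 8-channel
 * block of the row12x8-major GEMM output. */
static size_t DeconvSrcOffset(int ih, int iw, int kh, int kw, int ci,
                              int input_w, int kernel_w, int in_plane12) {
  size_t src_iw_stride = 8; /* C8NUM */
  size_t src_ih_stride = (size_t)input_w * 8;
  size_t src_kw_stride = (size_t)in_plane12 * 8;
  size_t src_kh_stride = (size_t)in_plane12 * kernel_w * 8;
  return kh * src_kh_stride + kw * src_kw_stride + ih * src_ih_stride + iw * src_iw_stride + (size_t)ci;
}
```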
@@ -52,7 +52,7 @@ int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *d
   for (int c = 0; c < oc8; c += 8) {
     float *dst_ptr = tmp + c * output_plane;
-    const float *src_ptr = src + c * in_plane8 * kernel_plane;
+    const float *src_ptr = src + c * in_plane12 * kernel_plane;
     memset(dst_ptr, 0, output_plane * C8NUM * sizeof(float));
     for (int ih = 0; ih < conv_param->input_h_; ih++) {
@@ -101,41 +101,3 @@ int DeConvPostFp32C8x8(const float *src, float *tmp, const float *bias, float *d
                     conv_param->is_relu6_);
   return NNACL_OK;
 }
-int DeConvPostFp32C4(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel,
-                     int input_plane, int kernel_plane, int output_plane, ConvParameter *conv_param) {
-  int oc4 = UP_DIV(output_channel, C4NUM);
-  for (int c = 0; c < oc4; c++) {
-    float *dst_ptr = tmp_c4 + c * output_plane * C4NUM;
-    const float *src_ptr = src + c * input_plane * kernel_plane * C4NUM;
-    memset(dst_ptr, 0, output_plane * C4NUM * sizeof(float));
-    for (int ih = 0; ih < conv_param->input_h_; ih++) {
-      for (int iw = 0; iw < conv_param->input_w_; iw++) {
-        int oh = ih * conv_param->stride_h_ - conv_param->pad_h_;
-        int ow = iw * conv_param->stride_w_ - conv_param->pad_w_;
-        int kh_start = MSMAX(0, UP_DIV(-oh, conv_param->dilation_h_));
-        int kh_end = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->output_h_ - oh, conv_param->dilation_h_));
-        int kw_start = MSMAX(0, UP_DIV(-ow, conv_param->dilation_w_));
-        int kw_end = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->output_w_ - ow, conv_param->dilation_w_));
-        for (int kh = kh_start; kh < kh_end; kh++) {
-          for (int kw = kw_start; kw < kw_end; kw++) {
-            int src_index = ih * conv_param->input_w_ * C4NUM + iw * C4NUM +
-                            kh * input_plane * conv_param->kernel_w_ * C4NUM + kw * input_plane * C4NUM;
-            int dst_index = oh * conv_param->output_w_ * C4NUM + ow * C4NUM +
-                            kh * conv_param->dilation_h_ * conv_param->output_w_ * C4NUM +
-                            kw * conv_param->dilation_w_ * C4NUM;
-            for (int i = 0; i < C4NUM; i++) {
-              dst_ptr[dst_index + i] += src_ptr[src_index + i];
-            }
-          } /*kw*/
-        }   /*kh*/
-      }     /*iw*/
-    }       /*ih*/
-  }         /*oc4*/
-  PostConvFuncFp32C4(tmp_c4, dst, bias, output_channel, output_plane, conv_param->output_channel_, conv_param->is_relu_,
-                     conv_param->is_relu6_);
-  return NNACL_OK;
-}
@@ -16,20 +16,19 @@
 #ifndef MINDSPORE_LITE_NNACL_FP32_DECONV_H_
 #define MINDSPORE_LITE_NNACL_FP32_DECONV_H_
 #include <string.h>
 #include "nnacl/pack.h"
 #include "nnacl/op_base.h"
 #include "nnacl/conv_parameter.h"
-#include "nnacl/fp32/strassen_matmul.h"
 #include "nnacl/errorcode.h"
 #include "nnacl/fp32/common_func.h"
 #ifdef __cplusplus
 extern "C" {
 #endif
 void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane);
-int DeConvPostFp32C4(const float *src, float *tmp_c4, float *dst, const float *bias, int output_channel,
-                     int input_plane, int kernel_plane, int output_plane, ConvParameter *conv_param);
-int DeConvPostFp32C8x8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel,
-                       ConvParameter *conv_param);
+int DeConvPostFp32C12x8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel,
+                        ConvParameter *conv_param);
 #ifdef __cplusplus
 }
 #endif
@@ -28,6 +28,18 @@ void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) {
   return;
 }
+void RowMajor2Row12Major(float *src_ptr, float *dst_ptr, int row, int col) {
+  for (int r = 0; r < row; r++) {
+    float *src = src_ptr + r * col;
+    for (int c = 0; c < col; c++) {
+      int cd12 = c / C12NUM;
+      int cm12 = c % C12NUM;
+      dst_ptr[cd12 * C12NUM * row + r * C12NUM + cm12] = src[c];
+    }
+  }
+  return;
+}
 void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
   size_t row12 = row / C12NUM * C12NUM;
   size_t col4 = col / C4NUM * C4NUM;
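Editorial note: the new `RowMajor2Row12Major` tiles columns into blocks of 12, so element (r, c) lands at `(c / 12) * 12 * row + r * 12 + (c % 12)`. (The quotient/remainder variables are named `cd12`/`cm12` above; they divide by C12NUM.) A tiny standalone check of that mapping (hypothetical `main`; `C12NUM` is 12, as in nnacl/op_base.h):

```c
#include <stdio.h>
#define C12NUM 12

/* Print the destination slot for every element of a 2 x 13 matrix; column 12
 * spills into a second 12-column block that starts after all rows of the
 * first block, i.e. at offset C12NUM * row = 24. */
int main(void) {
  int row = 2, col = 13;
  for (int r = 0; r < row; r++) {
    for (int c = 0; c < col; c++) {
      int dst = (c / C12NUM) * C12NUM * row + r * C12NUM + (c % C12NUM);
      printf("src(%d,%d) -> dst[%d]\n", r, c, dst);
    }
  }
  return 0;
}
```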
@@ -323,18 +335,18 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col
   return;
 }
-void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
-               int col, int stride, bool write_nhwc) {
-  if (write_nhwc) {
+void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
+                int col, int stride, int out_type) {
+  if (out_type == OutType_Nhwc) {
     /* col8-major * row8-major => col-major */
     for (int r = 0; r < row; r++) {
       for (int c = 0; c < col; c++) {
-        int r8div = r / 8, r8mod = r % 8;
+        int r12div = r / 12, r12mod = r % 12;
         int c8div = c / 8, c8mod = c % 8;
         size_t ci = r * stride + c;
         float value = 0;
         for (int d = 0; d < deep; d++) {
-          size_t ai = r8div * deep * 8 + d * 8 + r8mod;
+          size_t ai = r12div * deep * 12 + d * 12 + r12mod;
           size_t bi = c8div * deep * 8 + d * 8 + c8mod;
           value = value + a[ai] * b[bi];
         }
@@ -345,18 +357,20 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac
       }
     }
   } else {
-    /* col8-major * row8-major => col8x8-major */
+    /* col8-major * row8-major => col12x8-major */
     int col_8 = UP_ROUND(col, C8NUM);
-    int row_8 = UP_ROUND(row, C8NUM);
-    for (int r = 0; r < row_8; r++) {
+    int row_12 = UP_ROUND(row, C12NUM);
+    for (int r = 0; r < row_12; r++) {
       for (int c = 0; c < col_8; c++) {
-        int r8div = r / 8, r8mod = r % 8;
-        int c8div = c / 8, c8mod = c % 8;
-        size_t ci = c8div * row_8 * 8 + r * 8 + c8mod;
+        int r12div = r / C12NUM, r12mod = r % C12NUM;
+        int c8div = c / C8NUM, c8mod = c % C8NUM;
+        int c4div = c / C4NUM, c4mod = c % C4NUM;
+        size_t ci = (out_type == OutType_C4) ? (c4div * C4NUM * row_12 + r * C4NUM + c4mod)
+                                             : (c8div * C8NUM * row_12 + r * C8NUM + c8mod);
         float value = 0;
         for (int d = 0; d < deep; d++) {
-          size_t ai = r8div * deep * 8 + d * 8 + r8mod;
-          size_t bi = c8div * deep * 8 + d * 8 + c8mod;
+          size_t ai = r12div * deep * C12NUM + d * C12NUM + r12mod;
+          size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
           value = value + a[ai] * b[bi];
         }
         if (bias != NULL) value += bias[c];
@@ -369,45 +383,12 @@ void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, Ac
   return;
 }
-void MatMul12x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
-                int col, size_t stride, size_t writeNhwc, size_t writeC4) {
-  if (writeNhwc != 0) {
-    /* col8-major * row8-major => col-major */
-    for (int r = 0; r < row; r++) {
-      for (int c = 0; c < col; c++) {
-        int r12div = r / 12, r12mod = r % 12;
-        int c8div = c / 8, c8mod = c % 8;
-        size_t ci = r * stride + c;
-        float value = 0;
-        for (int d = 0; d < deep; d++) {
-          size_t ai = r12div * deep * 12 + d * 12 + r12mod;
-          size_t bi = c8div * deep * 8 + d * 8 + c8mod;
-          value = value + a[ai] * b[bi];
-        }
-        if (bias != NULL) value += bias[c];
-        if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
-        if (act_type != ActType_No) value = MSMAX(0.0f, value);
-        dst[ci] = value;
-      }
-    }
-  }
-  return;
-}
-void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
-            int stride, bool write_nhwc) {
-#ifdef ENABLE_ARM64
-  MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc);
-#else
-  MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
-#endif
-}
 void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
-               int col, size_t stride, size_t writeNhwc, size_t writeC4) {
+               int col, size_t stride, int out_type) {
 #ifdef ENABLE_ARM64
-  MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, writeNhwc, writeC4);
+  MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
+                       (int)(out_type == OutType_C4));
#else
-  MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, writeNhwc, writeC4);
+  MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
 #endif
 }
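Editorial note: the `out_type` tag (the `OutType` enum added to nnacl/matmul_parameter.h further down) replaces the old `writeNhwc`/`writeC4` flag pair with a single output-layout selector. A hedged usage fragment for the C reference path, with hypothetical sizes; it assumes A was packed with `RowMajor2Col12Major` and B with `RowMajor2Row8Major`:

```c
/* row = 5 rounds up to 12, col = 7 rounds up to 8, deep = 16. */
float a_pack[12 * 16] = {0}; /* col12-major A                      */
float b_pack[8 * 16] = {0};  /* row8-major B                       */
float out[5 * 7] = {0};      /* plain row-major (NHWC) destination */
MatMulOpt(a_pack, b_pack, out, /*bias=*/NULL, ActType_No,
          /*deep=*/16, /*row=*/5, /*col=*/7, /*stride=*/7, OutType_Nhwc);
```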
@@ -26,11 +26,11 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col,
-            int stride, bool write_nhwc);
-void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row,
-               int col, size_t stride, size_t writeNhwc, size_t writeC4);
+void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
+               int col, size_t stride, int out_type);
 void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
+void RowMajor2Row12Major(float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
 void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
 void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
@@ -38,7 +38,7 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col
 void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
                        int col, size_t stride, bool write_nhwc);
 void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
-                          int col, size_t stride, size_t writeNhwc, size_t writeC4);
+                          int col, size_t stride, size_t write_nhwc, size_t write_c4);
 #endif
 #ifdef __cplusplus
 }
@@ -1,204 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include "nnacl/fp32/strassen_matmul.h"
-bool CheckRecursion(int row, int col, int deep, int max_recursion, int cur_recursion) {
-  if (cur_recursion >= max_recursion) {
-    return false;
-  }
-  if (row % 2 != 0 || col % 2 != 0 || deep % 2 != 0) {
-    return false;
-  }
-  int row2 = row / 2;
-  int col2 = col / 2;
-  int deep2 = deep / 2;
-  float save_cost = row * col * 4 * deep * 4 * 2 + row * col * 4 -
-                    7 * (row2 * col2 * 4 * deep2 * 4 * 2 - row2 * col2 * 4) - 4 * (row2 * deep2 * 4 * 3) -
-                    4 * (deep2 * 4 * col2 * 4 * 3) - 7 * (row2 * col2 * 4 * 3);
-  return (save_cost > 0.f);
-}
-void GemmMatMulComm(const float *a_ptr, const float *b_ptr, float *dst_ptr, int row, int col, int deep, int b_stride,
-                    int c_stride) {
-  int row4mod = row % 4;
-  int row4div = row / 4;
-  for (int r = 0; r < row; r++) {
-    int r4mod = r % 4;
-    int r4div = r / 4;
-    for (int c = 0; c < col * 4; c++) {
-      float value = 0;
-      int ic = c / 4 * c_stride + r * 4 + c % 4;
-      for (int d = 0; d < deep * 4; d++) {
-        int d4mod = d % 4;
-        int d4div = d / 4;
-        int a_stride = (r < (row4div * 4)) ? 4 : row4mod;
-        int ai = r4div * 4 * deep * 4 + d4div * a_stride * 4 + r4mod * 4 + d4mod;
-        int bi = c / 4 * b_stride + d * 4 + c % 4;
-        value = value + a_ptr[ai] * b_ptr[bi];
-      }
-      dst_ptr[ic] = value;
-    }
-  }
-  return;
-}
-void GemmMatMul(const float *a_ptr, const float *b_ptr, float *dst_ptr, int row, int col, int deep, int b_stride,
-                int c_stride) {
-  int row4mod = row % 4;
-  int row4div = row / 4;
-  if (row4div > 0) {
-    GemmMatMulComm(a_ptr, b_ptr, dst_ptr, row4div * 4, col, deep, b_stride, c_stride);
-  }
-  if (row4mod != 0) {
-    GemmMatMulComm(a_ptr + row4div * deep * 4 * 4, b_ptr, dst_ptr + row4div * 4 * 4, row4mod, col, deep, b_stride,
-                   c_stride);
-  }
-  return;
-}
-int RecursionMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
-                    int max_recursion, int cur_recursion, float *tmp_a_ptr) {
-  size_t row2 = matmul_param->row_ / 2;
-  size_t deep2 = matmul_param->deep_ / 2;
-  size_t col2 = matmul_param->col_ / 2;
-  size_t a_stride = matmul_param->a_stride_;
-  size_t b_stride = matmul_param->b_stride_;
-  size_t c_stride = matmul_param->c_stride_;
-  StrassenMatMulParameter rec_matmul;
-  rec_matmul.row_ = row2;
-  rec_matmul.deep_ = deep2;
-  rec_matmul.col_ = col2;
-  float *x_ptr = (float *)(malloc(row2 * MSMAX(deep2, col2) * FP32_STRASSEN_UINT * sizeof(float)));
-  if (x_ptr == NULL) {
-    return NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC;
-  }
-  float *y_ptr = (float *)(malloc(col2 * deep2 * FP32_STRASSEN_WEIGHT_UINT * sizeof(float)));
-  if (y_ptr == NULL) {
-    free(x_ptr);
-    return NNACL_ERRCODE_STRASSEN_RECURSION_MALLOC;
-  }
-  size_t x_stride = row2 * FP32_STRASSEN_UINT;
-  size_t y_stride = deep2 * FP32_STRASSEN_WEIGHT_UINT;
-  const float *a11 = a_ptr;
-  const float *a12 = a_ptr + deep2 * a_stride;
-  const float *a21 = a_ptr + row2 * FP32_STRASSEN_UINT;
-  const float *a22 = a_ptr + deep2 * a_stride + row2 * FP32_STRASSEN_UINT;
-  const float *b11 = b_ptr;
-  const float *b12 = b_ptr + col2 * b_stride;
-  const float *b21 = b_ptr + deep2 * FP32_STRASSEN_WEIGHT_UINT;
-  const float *b22 = b_ptr + col2 * b_stride + deep2 * FP32_STRASSEN_WEIGHT_UINT;
-  float *c11 = c_ptr;
-  float *c12 = c_ptr + col2 * c_stride;
-  float *c21 = c_ptr + row2 * FP32_STRASSEN_UINT;
-  float *c22 = c_ptr + col2 * c_stride + row2 * FP32_STRASSEN_UINT;
-  /* S3 = A11 - A21 */
-  MatrixSub(a11, a21, x_ptr, a_stride, a_stride, x_stride, row2, deep2);
-  /* T3 = B22 - B12 */
-  MatrixSub(b22, b12, y_ptr, b_stride, b_stride, y_stride, deep2 * 4, col2);
-  /* P7 = S3T3 */
-  rec_matmul.a_stride_ = x_stride;
-  rec_matmul.b_stride_ = y_stride;
-  rec_matmul.c_stride_ = c_stride;
-  StrassenMatmul(x_ptr, y_ptr, c21, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
-  /* S1 = A21 + A22 */
-  MatrixAdd(a21, a22, x_ptr, a_stride, a_stride, x_stride, row2, deep2);
-  /* T1 = B12 - B11 */
-  MatrixSub(b12, b11, y_ptr, b_stride, b_stride, y_stride, deep2 * 4, col2);
-  /* P5 = S1T1 */
-  StrassenMatmul(x_ptr, y_ptr, c22, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
-  /* S2 = S1 - A11 */
-  MatrixSub(x_ptr, a11, x_ptr, x_stride, a_stride, x_stride, row2, deep2);
-  /* T2 = B22 - T1 */
-  MatrixSub(b22, y_ptr, y_ptr, b_stride, y_stride, y_stride, deep2 * 4, col2);
-  /* P6 = S2T2 */
-  StrassenMatmul(x_ptr, y_ptr, c12, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
-  /* S4 = A12 - S2 */
-  MatrixSub(a12, x_ptr, x_ptr, a_stride, x_stride, x_stride, row2, deep2);
-  /* P3 = S4B22 */
-  rec_matmul.b_stride_ = b_stride;
-  StrassenMatmul(x_ptr, b22, c11, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
-  /* P1 = A11B11 */
-  rec_matmul.a_stride_ = a_stride;
-  rec_matmul.c_stride_ = row2 * FP32_STRASSEN_UINT;
-  StrassenMatmul(a11, b11, x_ptr, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
-  /* U2 = P1 + P6
-     U3 = U2 + P7
-     U4 = U2 + P5
-     U7 = U3 + P5
-     U5 = U4 + P3 */
-  MatrixMultiAdd(c11, c12, c21, c22, x_ptr, row2, col2, c_stride, x_stride);
-  /* T4 = T2 - B21 */
-  MatrixSub(y_ptr, b21, y_ptr, y_stride, b_stride, y_stride, deep2 * 4, col2);
-  /* P4 = A22T4 */
-  rec_matmul.b_stride_ = y_stride;
-  rec_matmul.c_stride_ = c_stride;
-  StrassenMatmul(a22, y_ptr, c11, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
-  /* U6 = U3 - P4 */
-  MatrixSub(c21, c11, c21, c_stride, c_stride, c_stride, row2, col2);
-  /* P2 = A12B21 */
-  rec_matmul.b_stride_ = b_stride;
-  StrassenMatmul(a12, b21, c11, &rec_matmul, max_recursion, cur_recursion + 1, tmp_a_ptr);
-  /* U1 = P1 + P2 */
-  MatrixAdd(x_ptr, c11, c11, x_stride, c_stride, c_stride, row2, col2);
-  free(x_ptr);
-  free(y_ptr);
-  return NNACL_OK;
-}
-int CommonMatMul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
-                 float *tmp_a_ptr) {
-  MatrixPack(a_ptr, tmp_a_ptr, matmul_param->row_, matmul_param->deep_, matmul_param->a_stride_);
-  GemmMatMul(tmp_a_ptr, b_ptr, c_ptr, matmul_param->row_, matmul_param->col_, matmul_param->deep_,
-             matmul_param->b_stride_, matmul_param->c_stride_);
-  return NNACL_OK;
-}
-int StrassenMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
-                   int max_recursion, int cur_recursion, float *tmp_a_ptr) {
-  if (CheckRecursion(matmul_param->row_, matmul_param->col_, matmul_param->deep_, cur_recursion, max_recursion)) {
-    return RecursionMatmul(a_ptr, b_ptr, c_ptr, matmul_param, max_recursion, cur_recursion, tmp_a_ptr);
-  }
-  return CommonMatMul(a_ptr, b_ptr, c_ptr, matmul_param, tmp_a_ptr);
-}
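Editorial note: the gate that decided whether the removed Strassen path recursed is the `save_cost` expression in `CheckRecursion` above; recursion was taken only when seven half-size multiplies plus the extra matrix adds/subs were cheaper than one full multiply. A standalone sketch that evaluates the same arithmetic (hypothetical `main`; small blocks come out negative, so they always fell through to `CommonMatMul`):

```c
#include <stdio.h>

/* Same cost expression as CheckRecursion above, factored out for inspection. */
static float SaveCost(int row, int col, int deep) {
  int row2 = row / 2, col2 = col / 2, deep2 = deep / 2;
  return row * col * 4 * deep * 4 * 2 + row * col * 4 -
         7.0f * (row2 * col2 * 4 * deep2 * 4 * 2 - row2 * col2 * 4) -
         4.0f * (row2 * deep2 * 4 * 3) - 4.0f * (deep2 * 4 * col2 * 4 * 3) -
         7.0f * (row2 * col2 * 4 * 3);
}

int main(void) {
  printf("save_cost(8, 8, 8)    = %.0f\n", SaveCost(8, 8, 8));    /* -2432: no recursion    */
  printf("save_cost(64, 64, 64) = %.0f\n", SaveCost(64, 64, 64)); /* 761856: recursion wins */
  return 0;
}
```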
@@ -1,45 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_LITE_NNACL_FP32_STRASSEN_MATMUL_H_
-#define MINDSPORE_LITE_NNACL_FP32_STRASSEN_MATMUL_H_
-#include <memory.h>
-#include "nnacl/pack.h"
-#include "nnacl/op_base.h"
-#include "nnacl/errorcode.h"
-#include "nnacl/strassen_matmul.h"
-#include "nnacl/fp32/common_func.h"
-#define FP32_STRASSEN_UINT C4NUM
-#define FP32_STRASSEN_WEIGHT_UINT (C4NUM * C4NUM)
-#define FP32_STRASSEN_MAX_RECURSION 5
-#ifdef __cplusplus
-extern "C" {
-#endif
-int RecursionMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
-                    int max_recursion, int, float *tmp_a_ptr);
-int CommonMatMul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *Matmul_param,
-                 float *tmp_a_ptr);
-int StrassenMatmul(const float *a_ptr, const float *b_ptr, float *c_ptr, StrassenMatMulParameter *matmul_param,
-                   int max_recursion, int cur_recursion, float *tmp_a_ptr);
-#ifdef __cplusplus
-}
-#endif
-#endif  // MINDSPORE_LITE_NNACL_FP32_STRASSEN_MATMUL_H_
@@ -31,6 +31,8 @@ typedef void (*MAT_TRANS_FUNC)(void *dst, void *a, int row, int col);
 typedef enum ActType { ActType_No, ActType_Relu, ActType_Relu6 } ActType;
+typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_C4 = 2 } OutType;
 typedef struct MatMulParameter {
   OpParameter op_parameter_;
   int row_;
@@ -1,33 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#ifndef MINDSPORE_LITE_NNACL_STRASSEN_MATMUL_H_
-#define MINDSPORE_LITE_NNACL_STRASSEN_MATMUL_H_
-#include "nnacl/op_base.h"
-/* hw*inc4 X inc4*oc4 */
-typedef struct StrassenMatMulParameter {
-  OpParameter op_parameter;
-  int row_;      /* h * w */
-  int col_;      /* oc4 / 4 */
-  int deep_;     /* inc4 / 4 */
-  int a_stride_; /* h * w * 4 */
-  int b_stride_; /* inc4 * 4 */
-  int c_stride_; /* h * w * 4 */
-} StrassenMatMulParameter;
-#endif  // MINDSPORE_LITE_NNACL_STRASSEN_MATMUL_H_
@@ -39,6 +39,10 @@ void Convolution1x1CPUKernel::FreeTmpBuffer() {
     free(pack_input_);
     pack_input_ = nullptr;
   }
+  if (pre_trans_input_ && input_ptr_ != nullptr) {
+    free(input_ptr_);
+    input_ptr_ = nullptr;
+  }
   return;
 }
@@ -106,6 +110,16 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
     return RET_MEMORY_FAILED;
   }
   memset(pack_input_, 0, matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float));
+  if (pre_trans_input_) {
+    input_ptr_ = reinterpret_cast<float *>(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(float)));
+    if (input_ptr_ == nullptr) {
+      MS_LOG(ERROR) << "Conv1x1 Malloc input_ptr_ error!";
+      return RET_MEMORY_FAILED;
+    }
+    memset(input_ptr_, 0, matmul_param_->row_ * matmul_param_->deep_ * sizeof(float));
+  }
   return RET_OK;
 }
@@ -140,13 +154,10 @@ int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
   if (cur_oc <= 0) {
     return RET_OK;
   }
-  auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id;
   MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
-            output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
-            matmul_param_->row_, cur_oc, matmul_param_->col_, 1, 0);
+            output_ptr_ + task_id * thread_stride_, reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id,
+            matmul_param_->act_type_, matmul_param_->deep_, matmul_param_->row_, cur_oc, matmul_param_->col_,
+            OutType_Nhwc);
   return RET_OK;
 }
@@ -169,15 +180,6 @@ int Convolution1x1CPUKernel::Run() {
   auto src_in = reinterpret_cast<float *>(in_tensors_[0]->Data());
   auto src_out = reinterpret_cast<float *>(out_tensors_[0]->Data());
-  if (pre_trans_input_) {
-    input_ptr_ =
-      reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(float)));
-    if (input_ptr_ == nullptr) {
-      MS_LOG(ERROR) << "Conv1x1 Malloc input_ptr_ error!";
-      return RET_MEMORY_FAILED;
-    }
-  }
   for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) {
     Pre1x1Trans(src_in + batch_index * conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_,
                 src_out + batch_index * matmul_param_->row_ * matmul_param_->col_);
@@ -189,10 +191,6 @@ int Convolution1x1CPUKernel::Run() {
     }
   }
-  if (pre_trans_input_) {
-    ctx_->allocator->Free(input_ptr_);
-    input_ptr_ = nullptr;
-  }
   return RET_OK;
 }
 }  // namespace mindspore::kernel
@@ -95,13 +95,13 @@ int DeConvolutionCPUKernel::InitParam() {
   matmul_param_->row_ = input_plane_;
   matmul_param_->deep_ = conv_param_->input_channel_;
   matmul_param_->col_ = conv_param_->output_channel_ * kernel_plane_;
-  matmul_param_->row_8_ = UP_ROUND(matmul_param_->row_, C8NUM);
+  matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM);
   matmul_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * kernel_plane_;
   thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(conv_param_->output_channel_, C8NUM));
   thread_stride_ = UP_DIV(UP_DIV(conv_param_->output_channel_, C8NUM), thread_count_);
-  pack_input_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float)));
+  pack_input_ = reinterpret_cast<float *>(malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float)));
   if (pack_input_ == nullptr) {
     MS_LOG(ERROR) << "deconv Malloc pack_input_ error!";
     return RET_ERROR;
@@ -126,14 +126,14 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) {
     return RET_OK;
   }
-  auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_;
-  MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_, tmp_buffer,
-         nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_,
-         matmul_param_->col_, false);
+  auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_12_;
+  MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
+            tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_12_, oc * C8NUM * kernel_plane_,
+            matmul_param_->col_, OutType_C8);
-  DeConvPostFp32C8x8(tmp_buffer, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_,
-                     reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM,
-                     output_ptr_ + task_id * thread_stride_ * C8NUM, oc_res, conv_param_);
+  DeConvPostFp32C12x8(tmp_buffer, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_,
+                      reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM,
+                      output_ptr_ + task_id * thread_stride_ * C8NUM, oc_res, conv_param_);
   return RET_OK;
 }
@@ -165,7 +165,7 @@ int DeConvolutionCPUKernel::InitRunBuf() {
   }
   tmp_buffer_ =
-    reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float)));
+    reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->col_8_ * sizeof(float)));
   if (tmp_buffer_ == nullptr) {
     MS_LOG(ERROR) << "Conv1x1 Malloc tmp_buffer_ error!";
     return RET_NULL_PTR;
@@ -192,7 +192,7 @@ int DeConvolutionCPUKernel::Run() {
     input_ptr_ = src_in + batch_index * input_plane_ * conv_param_->input_channel_;
     output_ptr_ = src_out + batch_index * output_plane_ * conv_param_->output_channel_;
-    RowMajor2Col8Major(input_ptr_, pack_input_, input_plane_, conv_param_->input_channel_);
+    RowMajor2Col12Major(input_ptr_, pack_input_, input_plane_, conv_param_->input_channel_);
     error_code = LiteBackendParallelLaunch(DeConvFp32Run, this, thread_count_);
     if (error_code != RET_OK) {
@@ -27,18 +27,14 @@ FullconnectionCPUKernel::~FullconnectionCPUKernel() {
 }
 void FullconnectionCPUKernel::FreeBuf() {
-  if (a_c8_ptr_ != nullptr) {
-    free(a_c8_ptr_);
-    a_c8_ptr_ = nullptr;
+  if (a_c12_ptr_ != nullptr) {
+    free(a_c12_ptr_);
+    a_c12_ptr_ = nullptr;
   }
   if (b_r8_ptr_ != nullptr) {
     free(b_r8_ptr_);
     b_r8_ptr_ = nullptr;
   }
-  if (c_r8x8_ptr_ != nullptr) {
-    free(c_r8x8_ptr_);
-    c_r8x8_ptr_ = nullptr;
-  }
   if (bias_ptr_ != nullptr) {
     free(bias_ptr_);
     bias_ptr_ = nullptr;
@@ -51,8 +47,8 @@ int FullconnectionCPUKernel::ReSize() {
   fc_param_->col_ = (in_tensors_[1]->shape())[0];
   fc_param_->deep_ = (in_tensors_[1]->shape())[1];
-  fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8);
-  fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8);
+  fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM);
+  fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM);
   thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8));
   thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_);
@@ -63,11 +59,11 @@ int FullconnectionCPUKernel::ReSize() {
     memcpy(bias_ptr_, in_tensors_[2]->Data(), fc_param_->col_ * sizeof(float));
   }
-  a_c8_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_8_ * fc_param_->deep_ * sizeof(float)));
-  if (a_c8_ptr_ == nullptr) {
+  a_c12_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_12_ * fc_param_->deep_ * sizeof(float)));
+  if (a_c12_ptr_ == nullptr) {
     return RET_MEMORY_FAILED;
   }
-  memset(a_c8_ptr_, 0, fc_param_->row_8_ * fc_param_->deep_ * sizeof(float));
+  memset(a_c12_ptr_, 0, fc_param_->row_12_ * fc_param_->deep_ * sizeof(float));
   b_r8_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(float)));
   if (b_r8_ptr_ == nullptr) {
| } | |||
| memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(float)); | |||
| c_r8x8_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float))); | |||
| if (c_r8x8_ptr_ == nullptr) { | |||
| FreeBuf(); | |||
| return RET_MEMORY_FAILED; | |||
| } | |||
| memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(float)); | |||
| fc_param_->a_const_ = false; | |||
| fc_param_->b_const_ = false; | |||
| InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->Data()), a_c8_ptr_); | |||
| InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->Data()), a_c12_ptr_); | |||
| InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->Data()), b_r8_ptr_); | |||
| return RET_OK; | |||
| } | |||
@@ -105,7 +94,7 @@ void FullconnectionCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) {
     return;
   }
   fc_param_->a_const_ = true;
-  RowMajor2Col8Major(src_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_);
+  RowMajor2Col12Major(src_ptr, a_c12_ptr_, fc_param_->row_, fc_param_->deep_);
   return;
 }
@@ -132,15 +121,14 @@ int FcFp32MatmulRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
 }
 int FullconnectionCPUKernel::DoMatmul(int task_id) {
-  int cur_oc = MSMIN(thread_stride_, UP_DIV(fc_param_->col_8_, 8) - task_id * thread_stride_);
+  int cur_oc = MSMIN(thread_stride_ * C8NUM, fc_param_->col_ - task_id * thread_stride_ * C8NUM);
   if (cur_oc <= 0) {
     return RET_OK;
   }
-  MatMul(a_c8_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_,
-         c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_,
-         bias_ptr_ + task_id * thread_stride_ * C8NUM, fc_param_->act_type_, fc_param_->deep_, fc_param_->row_8_,
-         cur_oc * 8, 0, false);
+  MatMulOpt(a_c12_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_,
+            c_r_ptr + task_id * thread_stride_ * C8NUM, bias_ptr_ + task_id * thread_stride_ * C8NUM,
+            fc_param_->act_type_, fc_param_->deep_, fc_param_->row_, cur_oc, fc_param_->col_, OutType_Nhwc);
   return RET_OK;
 }
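Editorial note on the slicing change: each task now owns `thread_stride_` blocks of `C8NUM` output channels measured in raw columns, and `cur_oc` clamps the tail. A quick standalone check with hypothetical sizes:

```c
#include <stdio.h>
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

/* col = 20 output channels, thread_stride = 1 block of 8 columns per task:
 * tasks 0 and 1 get 8 channels, task 2 gets the 4-channel tail, and task 3
 * computes a non-positive slice, which the kernel treats as nothing to do. */
int main(void) {
  int col = 20, thread_stride = 1;
  for (int task_id = 0; task_id < 4; task_id++) {
    int cur_oc = MSMIN(thread_stride * 8, col - task_id * thread_stride * 8);
    printf("task %d -> cur_oc %d\n", task_id, cur_oc);
  }
  return 0;
}
```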
@@ -152,14 +140,13 @@ int FullconnectionCPUKernel::Run() {
   }
   auto a_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->Data());
   auto b_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->Data());
-  auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
+  c_r_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->Data());
-  InitMatrixA(a_ptr, a_c8_ptr_);
+  InitMatrixA(a_ptr, a_c12_ptr_);
   InitMatrixB(b_ptr, b_r8_ptr_);
   LiteBackendParallelLaunch(FcFp32MatmulRun, this, thread_count_);
-  Row8x8Major2RowMajor(c_r8x8_ptr_, output_ptr, fc_param_->row_, fc_param_->col_, fc_param_->col_);
   return RET_OK;
 }
 }  // namespace mindspore::kernel
@@ -47,9 +47,9 @@ class FullconnectionCPUKernel : public FullconnectionBaseCPUKernel {
   void InitMatrixB(float *src_ptr, float *dst_ptr);
  private:
-  float *a_c8_ptr_ = nullptr;
+  float *a_c12_ptr_ = nullptr;
   float *b_r8_ptr_ = nullptr;
-  float *c_r8x8_ptr_ = nullptr;
+  float *c_r_ptr = nullptr;
   float *bias_ptr_ = nullptr;
 };
 }  // namespace mindspore::kernel
@@ -28,18 +28,14 @@ namespace mindspore::kernel {
 MatmulCPUKernel::~MatmulCPUKernel() { FreeTmpBuffer(); }
 void MatmulCPUKernel::FreeTmpBuffer() {
-  if (a_c8_ptr_ != nullptr) {
-    ctx_->allocator->Free(a_c8_ptr_);
-    a_c8_ptr_ = nullptr;
+  if (a_c12_ptr_ != nullptr) {
+    ctx_->allocator->Free(a_c12_ptr_);
+    a_c12_ptr_ = nullptr;
   }
   if (b_r8_ptr_ != nullptr) {
     ctx_->allocator->Free(b_r8_ptr_);
     b_r8_ptr_ = nullptr;
   }
-  if (c_r8x8_ptr_ != nullptr) {
-    ctx_->allocator->Free(c_r8x8_ptr_);
-    c_r8x8_ptr_ = nullptr;
-  }
   if (bias_ptr_ != nullptr) {
     ctx_->allocator->Free(bias_ptr_);
     bias_ptr_ = nullptr;
@@ -66,45 +62,37 @@ int MatmulCPUKernel::ReSize() {
   params_->row_ = c_shape[c_shape.size() - 2];
   params_->col_ = c_shape[c_shape.size() - 1];
   params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
-  params_->row_8_ = UP_ROUND(params_->row_, 8);
+  params_->row_12_ = UP_ROUND(params_->row_, C12NUM);
   params_->col_8_ = UP_ROUND(params_->col_, 8);
   thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
   thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
-  a_c8_ptr_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(params_->row_8_ * params_->deep_ * sizeof(float)));
-  if (a_c8_ptr_ == nullptr) {
+  a_c12_ptr_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(params_->row_12_ * params_->deep_ * sizeof(float)));
+  if (a_c12_ptr_ == nullptr) {
     FreeTmpBuffer();
     return RET_MEMORY_FAILED;
   }
-  memset(a_c8_ptr_, 0, params_->row_8_ * params_->deep_ * sizeof(float));
+  memset(a_c12_ptr_, 0, params_->row_12_ * params_->deep_ * sizeof(float));
   b_r8_ptr_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(params_->col_8_ * params_->deep_ * sizeof(float)));
   if (b_r8_ptr_ == nullptr) {
     FreeTmpBuffer();
     return RET_MEMORY_FAILED;
   }
   memset(b_r8_ptr_, 0, params_->col_8_ * params_->deep_ * sizeof(float));
-  c_r8x8_ptr_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(params_->row_8_ * params_->col_8_ * sizeof(float)));
-  if (c_r8x8_ptr_ == nullptr) {
-    FreeTmpBuffer();
-    return RET_MEMORY_FAILED;
-  }
-  memset(c_r8x8_ptr_, 0, params_->row_8_ * params_->col_8_ * sizeof(float));
   params_->a_const_ = false;
   params_->b_const_ = false;
-  InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->Data()), a_c8_ptr_);
+  InitMatrixA(reinterpret_cast<float *>(in_tensors_[0]->Data()), a_c12_ptr_);
   InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->Data()), b_r8_ptr_);
-  bias_ptr_ = reinterpret_cast<float *>(malloc(params_->col_8_ * sizeof(float)));
-  if (bias_ptr_ == nullptr) {
-    FreeTmpBuffer();
-    return RET_MEMORY_FAILED;
-  }
-  memset(bias_ptr_, 0, params_->col_8_ * sizeof(float));
   if (in_tensors_.size() == 3) {
+    bias_ptr_ = reinterpret_cast<float *>(malloc(params_->col_8_ * sizeof(float)));
+    if (bias_ptr_ == nullptr) {
+      FreeTmpBuffer();
+      return RET_MEMORY_FAILED;
+    }
+    memset(bias_ptr_, 0, params_->col_8_ * sizeof(float));
     memcpy(bias_ptr_, in_tensors_[2]->Data(), params_->col_ * sizeof(float));
+  } else {
+    bias_ptr_ = nullptr;
   }
   return RET_OK;
@@ -120,9 +108,9 @@ void MatmulCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) {
   params_->a_const_ = true;
   if (params_->a_transpose_) {
-    RowMajor2Row8Major(src_ptr, dst_ptr, params_->deep_, params_->row_);
+    RowMajor2Row12Major(src_ptr, dst_ptr, params_->deep_, params_->row_);
   } else {
-    RowMajor2Col8Major(src_ptr, a_c8_ptr_, params_->row_, params_->deep_);
+    RowMajor2Col12Major(src_ptr, dst_ptr, params_->row_, params_->deep_);
   }
   return;
 }
@@ -152,18 +140,13 @@ int MatmulCPUKernel::Init() {
 }
 int MatmulCPUKernel::RunImpl(int task_id) {
-  int cur_oc = MSMIN(thread_stride_, UP_DIV(params_->col_8_, 8) - task_id * thread_stride_);
+  int cur_oc = MSMIN(thread_stride_ * C8NUM, params_->col_ - task_id * thread_stride_ * C8NUM);
   if (cur_oc <= 0) {
     return RET_OK;
   }
-  auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_;
-  auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_;
-  if (bias_ptr_) {
-    auto cur_bias = bias_ptr_ + task_id * thread_stride_ * C8NUM;
-    MatMul(a_c8_ptr_, cur_b, cur_c, cur_bias, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8, 0, false);
-  } else {
-    MatMul(a_c8_ptr_, cur_b, cur_c, NULL, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8, 0, false);
-  }
+  MatMulOpt(a_c12_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_,
+            c_r_ptr_ + task_id * thread_stride_ * C8NUM, bias_ptr_ + task_id * thread_stride_ * C8NUM, ActType_No,
+            params_->deep_, params_->row_, cur_oc, params_->col_, OutType_Nhwc);
   return RET_OK;
 }
@@ -192,13 +175,12 @@ int MatmulCPUKernel::Run() {
   for (int i = 0; i < params_->batch; ++i) {
     auto cur_a_ptr = a_ptr + i * a_stride;
     auto cur_b_ptr = b_ptr + i * b_stride;
-    auto cur_c_ptr = c_ptr + i * c_stride;
+    c_r_ptr_ = c_ptr + i * c_stride;
-    InitMatrixA(cur_a_ptr, a_c8_ptr_);
+    InitMatrixA(cur_a_ptr, a_c12_ptr_);
     InitMatrixB(cur_b_ptr, b_r8_ptr_);
     LiteBackendParallelLaunch(MatmulFloatRun, this, thread_count_);
-    Row8x8Major2RowMajor(c_r8x8_ptr_, cur_c_ptr, params_->row_, params_->col_, params_->col_);
   }
   return RET_OK;
 }
@@ -41,9 +41,9 @@ class MatmulCPUKernel : public MatmulBaseCPUKernel {
   void FreeTmpBuffer();
  private:
-  float *a_c8_ptr_ = nullptr;
+  float *a_c12_ptr_ = nullptr;
   float *b_r8_ptr_ = nullptr;
-  float *c_r8x8_ptr_ = nullptr;
+  float *c_r_ptr_ = nullptr;
   float *bias_ptr_ = nullptr;
 };
 }  // namespace mindspore::kernel
@@ -19,9 +19,8 @@
 #include "utils/log_adapter.h"
 #include "common/common_test.h"
 #include "src/common/file_utils.h"
-#include "src/runtime/kernel/arm/fp32/convolution_1x1.h"
 #include "nnacl/matmul_parameter.h"
-#include "nnacl/strassen_matmul.h"
+#include "src/runtime/kernel/arm/fp32/convolution_1x1.h"
 namespace mindspore {
 using mindspore::lite::tensor::Tensor;
@@ -548,14 +548,14 @@ TEST_F(TestDeConvolutionFp32, DeConvTest2) {
   float *correct;
   int total_size = DeConvTestInit2(&inputs_, &outputs_, deconv_param, &correct);
   lite::Context *ctx = new lite::Context;
-  ctx->thread_num_ = 4;
+  ctx->thread_num_ = 1;
   kernel::DeConvolutionCPUKernel *deconv =
     new kernel::DeConvolutionCPUKernel(reinterpret_cast<OpParameter *>(deconv_param), inputs_, outputs_, ctx, nullptr);
   deconv->Init();
   deconv->Run();
-  EXPECT_EQ(0, lite::CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size));
-  delete deconv_param;
+  CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
   delete deconv;
   for (auto t : inputs_) delete t;
   for (auto t : outputs_) delete t;
@@ -635,7 +635,6 @@ TEST_F(TestDeConvolutionFp32, DeConvTest3) {
   deconv->Run();
   CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
-  delete deconv_param;
   delete deconv;
   for (auto t : inputs_) delete t;
   for (auto t : outputs_) delete t;
@@ -723,7 +722,6 @@ TEST_F(TestDeConvolutionFp32, DeConvTest4) {
   uint64_t time_avg = cost / loop_count;
   printf("deconv fp32 average time : %f ms\n", time_avg / 1000.0f);
-  delete deconv_param;
   delete deconv;
   for (auto t : inputs_) delete t;
   for (auto t : outputs_) delete t;
@@ -1,369 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#include <iostream>
-#include <memory>
-#include "utils/log_adapter.h"
-#include "common/common_test.h"
-#include "src/common/file_utils.h"
-#include "mindspore/lite/nnacl/pack.h"
-#include "mindspore/lite/nnacl/fp32/strassen_matmul.h"
-#include "mindspore/lite/nnacl/conv_parameter.h"
-namespace mindspore {
-class TestStrassenFp32 : public mindspore::CommonTest {
- public:
-  TestStrassenFp32() {}
-};
-TEST_F(TestStrassenFp32, MatrixAdd1) {
-  float a[] = {0.06796285, 0.6176181, 0.33195993, 0.2752791, 0.36864007, 0.04605605, 0.33899087, 0.9820137,
-               0.49804246, 0.8242412, 0.8458231, 0.6530539, 0.6336898, 0.8367749, 0.57166654, 0.25895607,
-               0.90079665, 0.10585558, 0.8215811, 0.48977906, 0.7895138, 0.41816455, 0.18999523, 0.28736928,
-               0.5882977, 0.44262612, 0.65245426, 0.7834421, 0.60903394, 0.82289135, 0.03855767, 0.30543327,
-               0.37747085, 0, 0, 0, 0.590335, 0, 0, 0,
-               0.7578682, 0, 0, 0, 0.81001425, 0, 0, 0,
-               0.9487712, 0, 0, 0, 0.11742989, 0, 0, 0,
-               0.60004807, 0, 0, 0, 0.05973052, 0, 0, 0};
-  float b[] = {0.112120815, 0.6869974, 0.08290442, 0.43003577, 0.044390075, 0.23077105, 0.23964432, 0.4426781,
-               0.6612115, 0.14988606, 0.84881437, 0.032587975, 0.35028255, 0.41838303, 0.12859282, 0.060378596,
-               0.8272769, 0.6949804, 0.9120368, 0.12399232, 0.9292184, 0.7566025, 0.10235854, 0.015936268,
-               0.20426726, 0.9926392, 0.54714125, 0.7022856, 0.58746314, 0.95714045, 0.26433542, 0.9030878,
-               0.8596953, 0, 0, 0, 0.8341476, 0, 0, 0,
-               0.72301114, 0, 0, 0, 0.40733734, 0, 0, 0,
-               0.2873559, 0, 0, 0, 0.612321, 0, 0, 0,
-               0.5008707, 0, 0, 0, 0.2586266, 0, 0, 0};
-  float add[] = {0.18008366, 1.3046155, 0.41486436, 0.7053149, 0.41303015, 0.2768271, 0.5786352, 1.4246918,
-                 1.159254, 0.9741273, 1.6946375, 0.6856419, 0.9839724, 1.255158, 0.7002593, 0.3193347,
-                 1.7280736, 0.80083597, 1.7336179, 0.6137714, 1.7187322, 1.174767, 0.29235378, 0.30330554,
-                 0.792565, 1.4352653, 1.1995955, 1.4857277, 1.1964971, 1.7800318, 0.3028931, 1.2085211,
-                 1.2371662, 0, 0, 0, 1.4244826, 0, 0, 0,
-                 1.4808793, 0, 0, 0, 1.2173516, 0, 0, 0,
-                 1.2361271, 0, 0, 0, 0.72975093, 0, 0, 0,
-                 1.1009188, 0, 0, 0, 0.31835714, 0, 0, 0};
-  float out[64] = {0};
-  MatrixAdd(a, b, out, 32, 32, 32, 8, 2);
-  EXPECT_EQ(0, lite::CompareOutputData(out, add, 64));
-}
-TEST_F(TestStrassenFp32, MatrixAdd2) {
-  float a[] = {0.06796285, 0.6176181, 0.33195993, 0.2752791, 0.36864007, 0.04605605, 0.33899087, 0.9820137,
-               0.49804246, 0.8242412, 0.8458231, 0.6530539, 0.6336898, 0.8367749, 0.57166654, 0.25895607,
-               0.90079665, 0.10585558, 0.8215811, 0.48977906, 0.7895138, 0.41816455, 0.18999523, 0.28736928,
-               0.5882977, 0.44262612, 0.65245426, 0.7834421, 0.60903394, 0.82289135, 0.03855767, 0.30543327,
-               0, 0, 0, 0, 0, 0, 0, 0,
-               0, 0, 0, 0, 0.37747085, 0, 0, 0,
-               0.590335, 0, 0, 0, 0.7578682, 0, 0, 0,
-               0.81001425, 0, 0, 0, 0.9487712, 0, 0, 0,
-               0.11742989, 0, 0, 0, 0.60004807, 0, 0, 0,
-               0.05973052, 0, 0, 0, 0, 0, 0, 0,
-               0, 0, 0, 0, 0, 0, 0, 0};
-  float b[] = {0.112120815, 0.6869974, 0.08290442, 0.43003577, 0.044390075, 0.23077105, 0.23964432, 0.4426781,
-               0.6612115, 0.14988606, 0.84881437, 0.032587975, 0.35028255, 0.41838303, 0.12859282, 0.060378596,
-               0.8272769, 0.6949804, 0.9120368, 0.12399232, 0.9292184, 0.7566025, 0.10235854, 0.015936268,
-               0.20426726, 0.9926392, 0.54714125, 0.7022856, 0.58746314, 0.95714045, 0.26433542, 0.9030878,
-               0, 0, 0, 0, 0, 0, 0, 0,
-               0, 0, 0, 0, 0, 0, 0, 0,
-               0, 0, 0, 0, 0, 0, 0, 0,
-               0.8596953, 0, 0, 0, 0.8341476, 0, 0, 0,
-               0.72301114, 0, 0, 0, 0.40733734, 0, 0, 0,
-               0.2873559, 0, 0, 0, 0.612321, 0, 0, 0,
-               0.5008707, 0, 0, 0, 0.2586266, 0, 0, 0,
-               0, 0, 0, 0, 0, 0, 0, 0,
-               0, 0, 0, 0, 0, 0, 0, 0,
-               0, 0, 0, 0, 0, 0, 0, 0};
-  float add[] = {0.18008366, 1.3046155, 0.41486436, 0.7053149, 0.41303015, 0.2768271, 0.5786352, 1.4246918,
-                 1.159254, 0.9741273, 1.6946375, 0.6856419, 0.9839724, 1.255158, 0.7002593, 0.3193347,
-                 1.7280736, 0.80083597, 1.7336179, 0.6137714, 1.7187322, 1.174767, 0.29235378, 0.30330554,
-                 0.792565, 1.4352653, 1.1995955, 1.4857277, 1.1964971, 1.7800318, 0.3028931, 1.2085211,
-                 0, 0, 0, 0, 1.2371662, 0, 0, 0,
-                 1.4244826, 0, 0, 0, 1.4808793, 0, 0, 0,
-                 1.2173516, 0, 0, 0, 1.2361271, 0, 0, 0,
-                 0.72975093, 0, 0, 0, 1.1009188, 0, 0, 0,
-                 0.31835714, 0, 0, 0, 0, 0, 0, 0};
-  float out[72] = {0};
-  MatrixAdd(a, b, out, 44, 56, 36, 8, 2);
-  EXPECT_EQ(0, lite::CompareOutputData(out, add, 72));
-}
-TEST_F(TestStrassenFp32, MatrixSub1) {
-  float a[] = {0.4160896, 0.55011475, 0.60395557, 0.964036, 0.8010256, 0.908257, 0.60170764, 0.008877548,
-               0.4973592, 0.6104505, 0.2957374, 0.39589414, 0.0151615525, 0.45663023, 0.3815148, 0.6419536,
-               0.9118046, 0.5312479, 0.104496025, 0.5972911, 0.9671534, 0.7195669, 0.23360363, 0.22078007,
-               0.31118092, 0.7438336, 0.5592656, 0.7212792, 0.97856164, 0.26012093, 0.18205991, 0.90656054,
-               0.24593723, 0, 0, 0, 0.5024593, 0, 0, 0,
-               0.42271087, 0, 0, 0, 0.48668534, 0, 0, 0,
-               0.4374295, 0, 0, 0, 0.22822042, 0, 0, 0,
-               0.88180095, 0, 0, 0, 0.7505223, 0, 0, 0};
-  float b[] = {0.14911577, 0.63214976, 0.74834836, 0.36854064, 0.5801671, 0.24166176, 0.64528674, 0.04887214,
-               0.23637155, 0.34321627, 0.69035923, 0.6114065, 0.73006815, 0.575073, 0.88130534, 0.72951907,
-               0.17092401, 0.652334, 0.6288812, 0.62121505, 0.12793411, 0.16503152, 0.7564361, 0.51976234,
-               0.19353953, 0.5795124, 0.6671185, 0.10646773, 0.13608798, 0.37959677, 0.24294423, 0.1790138,
-               0.85054415, 0, 0, 0, 0.18541782, 0, 0, 0,
-               0.72714496, 0, 0, 0, 0.43221787, 0, 0, 0,
-               0.7200413, 0, 0, 0, 0.15780604, 0, 0, 0,
-               0.30473796, 0, 0, 0, 0.37719592, 0, 0, 0};
-  float s[] = {0.26697382, -0.082035, -0.14439279, 0.59549534, 0.22085851, 0.6665952, -0.0435791, -0.03999459,
-               0.26098764, 0.26723424, -0.39462185, -0.21551237, -0.7149066, -0.11844277, -0.49979055, -0.08756548,
-               0.7408806, -0.12108606, -0.5243852, -0.02392393, 0.8392193, 0.5545354, -0.5228325, -0.29898226,
-               0.11764139, 0.16432118, -0.10785288, 0.6148115, 0.8424736, -0.11947584, -0.06088431, 0.72754675,
-               -0.6046069, 0., 0., 0., 0.31704146, 0., 0., 0.,
-               -0.3044341, 0., 0., 0., 0.05446747, 0., 0., 0.,
-               -0.2826118, 0., 0., 0., 0.07041438, 0., 0., 0.,
-               0.57706296, 0., 0., 0., 0.3733264, 0., 0., 0.};
-  float out[64] = {0};
-  MatrixSub(a, b, out, 32, 32, 32, 8, 2);
-  EXPECT_EQ(0, lite::CompareOutputData(out, s, 64));
-}
-TEST_F(TestStrassenFp32, MatrixSub2) {
-  float a[] = {0.4160896, 0.55011475, 0.60395557, 0.964036, 0.8010256, 0.908257, 0.60170764, 0.008877548,
-               0.4973592, 0.6104505, 0.2957374, 0.39589414, 0.0151615525, 0.45663023, 0.3815148, 0.6419536,
-               0.9118046, 0.5312479, 0.104496025, 0.5972911, 0.9671534, 0.7195669, 0.23360363, 0.22078007,
-               0.31118092, 0.7438336, 0.5592656, 0.7212792, 0.97856164, 0.26012093, 0.18205991, 0.90656054,
-               0.24593723, 0, 0, 0, 0.5024593, 0, 0, 0,
-               0.42271087, 0, 0, 0, 0.48668534, 0, 0, 0,
-               0.4374295, 0, 0, 0, 0.22822042, 0, 0, 0,
-               0.88180095, 0, 0, 0, 0.7505223, 0, 0, 0};
-  float b[] = {0.14911577, 0.63214976, 0.74834836, 0.36854064, 0.5801671, 0.24166176, 0.64528674, 0.04887214,
-               0.23637155, 0.34321627, 0.69035923, 0.6114065, 0.73006815, 0.575073, 0.88130534, 0.72951907,
-               0.17092401, 0.652334, 0.6288812, 0.62121505, 0.12793411, 0.16503152, 0.7564361, 0.51976234,
-               0.19353953, 0.5795124, 0.6671185, 0.10646773, 0.13608798, 0.37959677, 0.24294423, 0.1790138,
-               0, 0, 0, 0, 0, 0, 0, 0,
-               0, 0, 0, 0, 0.85054415, 0, 0, 0,
-               0.18541782, 0, 0, 0, 0.72714496, 0, 0, 0,
-               0.43221787, 0, 0, 0, 0.7200413, 0, 0, 0,
-               0.15780604, 0, 0, 0, 0.30473796, 0, 0, 0,
-               0.37719592, 0, 0, 0, 0, 0, 0, 0,
-               0, 0, 0, 0, 0, 0, 0, 0};
-  float s[] = {0.26697382, -0.082035, -0.14439279, 0.59549534, 0.22085851, 0.6665952, -0.0435791, -0.03999459,
-               0.26098764, 0.26723424, -0.39462185, -0.21551237, -0.7149066, -0.11844277, -0.49979055, -0.08756548,
-               0.7408806, -0.12108606, -0.5243852, -0.02392393, 0.8392193, 0.5545354, -0.5228325, -0.29898226,
-               0.11764139, 0.16432118, -0.10785288, 0.6148115, 0.8424736, -0.11947584, -0.06088431, 0.72754675,
-               0, 0, 0, 0, -0.6046069, 0., 0., 0.,
-               0.31704146, 0., 0., 0., -0.3044341, 0., 0., 0.,
-               0.05446747, 0., 0., 0., -0.2826118, 0., 0., 0.,
-               0.07041438, 0., 0., 0., 0.57706296, 0., 0., 0.,
-               0.3733264, 0., 0., 0, 0, 0, 0, 0.};
-  float out[72] = {0};
-  MatrixSub(a, b, out, 32, 44, 36, 8, 2);
-  EXPECT_EQ(0, lite::CompareOutputData(out, s, 72));
-}
| TEST_F(TestStrassenFp32, MatrixPack1) { | |||
| float in[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36, | |||
| -0.784, 37.925, -0.081, 6.1298, 37.998, 13.719, 11.029, 1.7127, 9.0560, 14.988, 3.1866, 0.0562, | |||
| 14.530, -14.10, -8.115, -8.071, 19.250, 17.923, 13.584, 3.3293, -1.514, -0.293, 18.686, 0.0873, | |||
| 19.899, 8.5562, 0.0, 0.0, 9.5542, 18.974, 0.0, 0.0, 15.370, 4.3049, 0.0, 0.0, | |||
| 0.6721, -1.517, 0.0, 0.0, -1.770, 41.903, 0.0, 0.0, 8.1381, 9.1391, 0.0, 0.0, | |||
| -8.158, 7.7566, 0.0, 0.0, 9.7341, 18.834, 0.0, 0.0, 4.2010, -2.253, 0.0, 0.0}; | |||
| float correct[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36, | |||
| -0.784, 37.925, -0.081, 6.1298, 19.899, 8.5562, 0.0, 0.0, 9.5542, 18.974, 0.0, 0.0, | |||
| 15.370, 4.3049, 0.0, 0.0, 0.6721, -1.517, 0.0, 0.0, 37.998, 13.719, 11.029, 1.7127, | |||
| 9.0560, 14.988, 3.1866, 0.0562, 14.530, -14.10, -8.115, -8.071, -1.770, 41.903, 0.0, 0.0, | |||
| 8.1381, 9.1391, 0.0, 0.0, -8.158, 7.7566, 0.0, 0.0}; | |||
| float out[56] = {0}; | |||
  MatrixPack(in, out, 7, 2, 36);
  EXPECT_EQ(0, lite::CompareOutputData(out, correct, 56));
}

TEST_F(TestStrassenFp32, MatrixPack2) {
  float in[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36,
                -0.784, 37.925, -0.081, 6.1298, 37.998, 13.719, 11.029, 1.7127, 9.0560, 14.988, 3.1866, 0.0562,
                14.530, -14.10, -8.115, -8.071, 19.250, 17.923, 13.584, 3.3293, -1.514, -0.293, 18.686, 0.0873,
                19.899, 8.5562, 0.0, 0.0, 9.5542, 18.974, 0.0, 0.0, 15.370, 4.3049, 0.0, 0.0,
                0.6721, -1.517, 0.0, 0.0, -1.770, 41.903, 0.0, 0.0, 8.1381, 9.1391, 0.0, 0.0,
                -8.158, 7.7566, 0.0, 0.0, 9.7341, 18.834, 0.0, 0.0, 4.2010, -2.253, 0.0, 0.0};
  float correct[] = {4.1795, 13.142, -3.593, 16.505, 19.969, -6.235, -2.380, -9.027, 23.622, 8.3608, 47.325, -14.36,
                     -0.784, 37.925, -0.081, 6.1298, 19.899, 8.5562, 0.0, 0.0, 9.5542, 18.974, 0.0, 0.0,
                     15.370, 4.3049, 0.0, 0.0, 0.6721, -1.517, 0.0, 0.0, 37.998, 13.719, 11.029, 1.7127,
                     9.0560, 14.988, 3.1866, 0.0562, 14.530, -14.10, -8.115, -8.071, 19.250, 17.923, 13.584, 3.3293,
                     -1.770, 41.903, 0.0, 0.0, 8.1381, 9.1391, 0.0, 0.0, -8.158, 7.7566, 0.0, 0.0,
                     9.7341, 18.834, 0.0, 0.0, -1.514, -0.293, 18.686, 0.0873, 4.2010, -2.253, 0.0, 0.0};
  float out[72] = {0};
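  // Same repacking with row = 9: the rows should come out in 4 + 4 + 1 groups per tile,
  // exercising the remainder-row path (expectation inferred from `correct`, not verified
  // against the MatrixPack source).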
  MatrixPack(in, out, 9, 2, 36);
  EXPECT_EQ(0, lite::CompareOutputData(out, correct, 72));
}

TEST_F(TestStrassenFp32, CommonMatmul1) {
  float a_ptr[] = {7.756654, 19.250782, 17.923292, 0, 13.584222, 3.3293908, 9.734102, 0,
                   18.83455, -1.51425, -0.29382, 0, 18.686155, 0.0873076, 4.2010098, 0,
                   -2.2539594, 4.1795673, 13.14235, 0, -3.59393, 16.50578, 19.899279, 0,
                   8.556229, 19.969376, -6.2355065, 0, -2.380469, -9.027744, 9.5542, 0};
  float b_ptr[] = {0.2674241, 0.089372, -0.081915, 2.0580146, -0.295045, 1.377944, 0.703658, 1.055378,
                   1.204049, -0.256505, -0.309640, 0.560465, 0, 0, 0, 0,
                   0.646906, 0, 0, 0, -0.168206, 0, 0, 0,
                   -0.95630, 0, 0, 0, 0, 0, 0, 0};
  float correct[] = {17.97499, 22.622334, 7.360805, 46.325558, 14.37076, 3.304931, -1.784072, 36.925926,
                     5.129812, -0.3278886, -2.517368, 36.99899, 10.029593, 0.7127603, -2.77004, 40.90305,
                     13.988123, 2.186689, -0.943787, 7.138184, 18.128653, 17.31859, 5.7472067, 21.176342,
                     -11.11159, 29.880829, 15.281498, 35.1893, 13.530734, -15.10318, -9.11581, -9.071925,
                     -15.36046, 0, 0, 0, -1.081104, 0, 0, 0,
                     12.719885, 0, 0, 0, 8.056052, 0, 0, 0,
                     -14.72927, 0, 0, 0, -24.1311, 0, 0, 0,
                     8.139168, 0, 0, 0, -9.158176, 0, 0, 0};
  StrassenMatMulParameter *matmul_param = new StrassenMatMulParameter();
  matmul_param->row_ = 8;
  matmul_param->deep_ = 1;
  matmul_param->col_ = 2;
  matmul_param->a_stride_ = 32;
  matmul_param->b_stride_ = 16;
  matmul_param->c_stride_ = 32;
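  // Field semantics as this test appears to use them (an inference from the buffer sizes,
  // not documentation): row_ counts rows, deep_ and col_ count 4-float tiles of the K and N
  // dimensions, and the *_stride_ values are distances in floats between consecutive tiles
  // of A, B and C respectively.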
  float c_ptr[64] = {0};
  float tmp_ptr[32];
  CommonMatMul(a_ptr, b_ptr, c_ptr, matmul_param, tmp_ptr);
  EXPECT_EQ(0, lite::CompareOutputData(c_ptr, correct, 64));
  delete matmul_param;
}

TEST_F(TestStrassenFp32, CommonMatmul2) {
  StrassenMatMulParameter *matmul_param = new StrassenMatMulParameter();
  float a[] = {4.864725, 6.830073, 0.76780415, 8.922394, 5.096872, 2.4946148, 4.2148714, 1.7762588, 0.89195687,
               9.703938, 2.0654619, 9.048538, 2.358036, 5.643526, 2.5152204, 3.512572, 3.7913973, 3.7136157,
               8.820186, 1.5324963, 3.135459, 7.5792265, 7.1820426, 0.267987, 8.737802, 4.064117, 2.7232447,
               0.27355433, 0, 0, 0, 0, 0, 0, 0, 0,
               6.320409, 9.479354, 0, 0, 1.6220464, 0.57753897, 0, 0, 9.786372,
               6.0404425, 0, 0, 2.1067812, 4.8034563, 0, 0, 2.1140356, 8.204062,
               0, 0, 3.29985, 1.2034118, 0, 0, 7.6059656, 4.162436, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0, 0};
  float b[] = {
    4.4558744, 0.6383263, 0.05037839, 9.730914, 8.1542015, 4.3625517, 8.654026, 3.805875, 9.845131, 4.08051,
    9.667656, 7.73955, 9.283867, 8.465257, 2.292051, 9.853942, 0.13320169, 3.8789113, 9.460265, 4.2616735,
    0.23831692, 4.420147, 0.5355651, 7.829217, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 1.9866786, 0, 0, 0, 6.0188327, 0,
    0, 0, 6.6249146, 0, 0, 0, 3.5639563, 0, 0, 0,
    0.14810833, 0, 0, 0, 7.4168983, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0};
  float c[] = {170.86482, 177.98166, 152.0957, 268.3473, 101.39282, 55.216248, 82.31873, 120.65008, 190.18558,
               192.58974, 220.54767, 239.75931, 115.32386, 95.52758, 103.82857, 145.08948, 150.4757, 112.04814,
               145.50496, 207.63342, 149.6962, 84.76027, 167.65851, 141.06763, 103.42963, 84.63687, 136.74927,
               189.26935, 0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 158.90288, 0, 0, 0, 63.917973,
               0, 0, 0, 152.3613, 0, 0, 0, 103.77265, 0,
               0, 0, 154.94044, 0, 0, 0, 109.79707, 0, 0,
               0, 92.83551, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0};
  matmul_param->row_ = 7;
  matmul_param->deep_ = 2;
  matmul_param->col_ = 2;
  matmul_param->a_stride_ = 36;
  matmul_param->b_stride_ = 64;
  matmul_param->c_stride_ = 40;
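  // Here row_ = 7 and the strides (36/64/40 floats) are all larger than the tiles they
  // separate, so the test presumably checks that CommonMatMul honours arbitrary tile
  // strides and leaves the padding floats in `out` untouched (they stay zero in `c`).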
  float out[80] = {0};
  float tmp_ptr[1000];
  CommonMatMul(a, b, out, matmul_param, tmp_ptr);
  EXPECT_EQ(0, lite::CompareOutputData(out, c, 80));
  delete matmul_param;
}
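
// For reference, one Strassen recursion level (which RecursionMatmul presumably performs
// before falling back to CommonMatMul) splits A, B and C into 2x2 blocks and forms seven
// products instead of eight:
//   M1 = (A11 + A22)(B11 + B22)    M5 = (A11 + A12) B22
//   M2 = (A21 + A22) B11           M6 = (A21 - A11)(B11 + B12)
//   M3 = A11 (B12 - B22)           M7 = (A12 - A22)(B21 + B22)
//   M4 = A22 (B21 - B11)
//   C11 = M1 + M4 - M5 + M7,  C12 = M3 + M5,  C21 = M2 + M4,  C22 = M1 - M2 + M3 + M6
// This is the textbook scheme; the exact block order used by this implementation has not
// been checked against the source.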
TEST_F(TestStrassenFp32, RecMatmul1) {
  StrassenMatMulParameter *matmul_param = new StrassenMatMulParameter();
  matmul_param->row_ = 4;
  matmul_param->deep_ = 2;
  matmul_param->col_ = 2;
  matmul_param->a_stride_ = 16;
  matmul_param->b_stride_ = 32;
  matmul_param->c_stride_ = 16;
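  // A 4-row by 2-tile by 2-tile case: each dimension can still be halved into whole tiles,
  // so one Strassen split appears feasible here before the recursion falls back to the
  // plain tiled multiply (an inference, not checked against the implementation).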
  float a[] = {9.02165, 8.657163, 0.56371903, 0.7272156, 1.6258951, 9.919627, 7.47593, 3.5311592,
               8.958062, 0.55338514, 9.611276, 7.429841, 8.23804, 3.7503464, 1.2829816, 6.4470887,
               4.303486, 6.282502, 0, 0, 9.4194765, 7.8199654, 0, 0,
               6.738705, 7.5398073, 0, 0, 0.47684374, 0.87746763, 0, 0};
  float b[] = {1.8100919, 6.016964, 5.733568, 5.768448, 2.2823029, 2.173359, 0.56861514, 7.134393,
               0.26377398, 3.9010656, 4.868408, 0.33401546, 1.7973539, 8.21896, 5.62239, 8.54786,
               0.97356945, 1.0714527, 6.447588, 6.161091, 3.332229, 2.8775468, 6.558747, 2.6986659,
               0, 0, 0, 0, 0, 0, 0, 0,
               1.9830805, 0, 0, 0, 8.44718, 0, 0, 0,
               9.360418, 0, 0, 0, 6.220693, 0, 0, 0,
               1.8369701, 0, 0, 0, 4.3965054, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0};
  float c[] = {62.668518, 103.9633, 132.43439, 163.67749, 69.12974, 122.12326, 183.23413, 191.96806,
               65.052124, 182.57918, 233.14148, 184.20694, 38.785316, 118.74806, 100.689575, 135.12036,
               136.34613, 0, 0, 0, 230.64507, 0, 0, 0,
               204.15103, 0, 0, 0, 104.86488, 0, 0, 0};
  float out[32] = {0};
  float tmp_ptr[1000];
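  // Hedged reading of the signature: RecursionMatmul(a, b, c, param, max_recursion,
  // cur_depth, workspace); the 1/0 pair below should permit exactly one split before the
  // recursion bottoms out (argument meaning inferred from the call sites in this file).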
  RecursionMatmul(a, b, out, matmul_param, 1, 0, tmp_ptr);
  EXPECT_EQ(0, lite::CompareOutputData(out, c, 32));
  delete matmul_param;
}

TEST_F(TestStrassenFp32, RecMatmul2) {
  StrassenMatMulParameter *matmul_param = new StrassenMatMulParameter();
  matmul_param->row_ = 4;
  matmul_param->deep_ = 2;
  matmul_param->col_ = 2;
  matmul_param->a_stride_ = 32;
  matmul_param->b_stride_ = 64;
  matmul_param->c_stride_ = 32;
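  // Identical multiply to RecMatmul1, but with all strides doubled: the 1,2,3,4 filler
  // rows interleaved into a, b and c below occupy the stride gaps and must be skipped,
  // and the corresponding regions of `out` are expected to stay zero.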
  float a[] = {9.02165, 8.657163, 0.56371903, 0.7272156, 1.6258951, 9.919627, 7.47593, 3.5311592,
               8.958062, 0.55338514, 9.611276, 7.429841, 8.23804, 3.7503464, 1.2829816, 6.4470887,
               1, 2, 3, 4, 1, 2, 3, 4,
               3, 2, 3, 4, 4, 2, 3, 4,
               4.303486, 6.282502, 0, 0, 9.4194765, 7.8199654, 0, 0,
               6.738705, 7.5398073, 0, 0, 0.47684374, 0.87746763, 0, 0,
               1, 2, 3, 4, 1, 2, 3, 4,
               3, 2, 3, 4, 4, 2, 3, 4};
  float b[] = {
    1.8100919, 6.016964, 5.733568, 5.768448, 2.2823029, 2.173359, 0.56861514, 7.134393, 0.26377398, 3.9010656,
    4.868408, 0.33401546, 1.7973539, 8.21896, 5.62239, 8.54786, 0.97356945, 1.0714527, 6.447588, 6.161091,
    3.332229, 2.8775468, 6.558747, 2.6986659, 0, 0, 0, 0, 0, 0,
    0, 0, 11, 2, 3, 4, 22, 2, 3, 4,
    33, 3, 3, 4, 44, 2, 3, 4, 11, 2,
    3, 4, 22, 2, 3, 4, 33, 3, 3, 4,
    44, 2, 3, 4, 1.9830805, 0, 0, 0, 8.44718, 0,
    0, 0, 9.360418, 0, 0, 0, 6.220693, 0, 0, 0,
    1.8369701, 0, 0, 0, 4.3965054, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 11, 2, 3, 4,
    22, 2, 3, 4, 33, 3, 3, 4, 44, 2,
    3, 4, 11, 2, 3, 4, 22, 2, 3, 4,
    33, 3, 3, 4, 44, 2, 3, 4};
  float c[] = {62.668518, 103.9633, 132.43439, 163.67749, 69.12974, 122.12326, 183.23413, 191.96806,
               65.052124, 182.57918, 233.14148, 184.20694, 38.785316, 118.74806, 100.689575, 135.12036,
               0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0,
               136.34613, 0, 0, 0, 230.64507, 0, 0, 0,
               204.15103, 0, 0, 0, 104.86488, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0,
               0, 0, 0, 0, 0, 0, 0, 0};
  float out[64] = {0};
  float tmp_ptr[1000];
  RecursionMatmul(a, b, out, matmul_param, 1, 0, tmp_ptr);
  EXPECT_EQ(0, lite::CompareOutputData(out, c, 64));
  delete matmul_param;
}
}  // namespace mindspore