|
|
@@ -36,91 +36,106 @@ using mindspore::schema::PrimitiveType_Conv2D; |
|
|
namespace mindspore::kernel { |
|
|
namespace mindspore::kernel { |
|
|
int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_t *weight_data, float *matrix_g, |
|
|
int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_t *weight_data, float *matrix_g, |
|
|
float *matrix_gt, int oc_block) { |
|
|
float *matrix_gt, int oc_block) { |
|
|
|
|
|
if (oc_block == 0) { |
|
|
|
|
|
MS_LOG(ERROR) << "Divide by zero"; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
// original weight format : ohwi |
|
|
// original weight format : ohwi |
|
|
auto channel_in = conv_param_->input_channel_; |
|
|
auto channel_in = conv_param_->input_channel_; |
|
|
auto channel_out = conv_param_->output_channel_; |
|
|
auto channel_out = conv_param_->output_channel_; |
|
|
int input_unit_square = input_unit_ * input_unit_; |
|
|
|
|
|
int oc_block_num = UP_DIV(channel_out, oc_block); |
|
|
int oc_block_num = UP_DIV(channel_out, oc_block); |
|
|
|
|
|
int block_stride = channel_in * oc_block; |
|
|
|
|
|
int block_num_stride = block_stride * oc_block_num; |
|
|
|
|
|
|
|
|
auto matrix_g_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t))); |
|
|
|
|
|
if (matrix_g_data_fp16 == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "malloc matrix_g_data_fp16 failed."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
auto matrix_gt_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t))); |
|
|
auto matrix_gt_data_fp16 = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t))); |
|
|
if (matrix_gt_data_fp16 == nullptr) { |
|
|
if (matrix_gt_data_fp16 == nullptr) { |
|
|
free(matrix_g_data_fp16); |
|
|
|
|
|
MS_LOG(ERROR) << "malloc matrix_gt_data_fp16 failed."; |
|
|
MS_LOG(ERROR) << "malloc matrix_gt_data_fp16 failed."; |
|
|
return RET_ERROR; |
|
|
return RET_ERROR; |
|
|
} |
|
|
} |
|
|
Float32ToFloat16(matrix_g, matrix_g_data_fp16, input_unit_ * kernel_unit_); |
|
|
|
|
|
Float32ToFloat16(matrix_gt, matrix_gt_data_fp16, input_unit_ * kernel_unit_); |
|
|
Float32ToFloat16(matrix_gt, matrix_gt_data_fp16, input_unit_ * kernel_unit_); |
|
|
|
|
|
|
|
|
// trans_filter = G*g*GT (g represents weight_data) |
|
|
|
|
|
// separate into two steps ===> tmp = G*g ===> out = tmp * GT |
|
|
|
|
|
auto tmp_weight_data = reinterpret_cast<float16_t *>(malloc(kernel_unit_ * kernel_unit_ * sizeof(float16_t))); |
|
|
|
|
|
if (tmp_weight_data == nullptr) { |
|
|
|
|
|
free(matrix_g_data_fp16); |
|
|
|
|
|
free(matrix_gt_data_fp16); |
|
|
|
|
|
MS_LOG(ERROR) << "malloc tmp_weight_data failed."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
auto tmp_data = reinterpret_cast<float16_t *>(malloc(input_unit_ * kernel_unit_ * sizeof(float16_t))); |
|
|
|
|
|
|
|
|
// trans_filter = G*g*GT (g represents weight_data) = [(g * (G)T)T * (G)T]T |
|
|
|
|
|
// separate into two steps ===> tmp = (g * (G)T)T ===> out = [tmp * (G)T]T |
|
|
|
|
|
auto tmp_data = reinterpret_cast<float16_t *>(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float16_t))); |
|
|
if (tmp_data == nullptr) { |
|
|
if (tmp_data == nullptr) { |
|
|
free(tmp_weight_data); |
|
|
|
|
|
free(matrix_g_data_fp16); |
|
|
|
|
|
free(matrix_gt_data_fp16); |
|
|
free(matrix_gt_data_fp16); |
|
|
MS_LOG(ERROR) << "malloc tmp_data failed."; |
|
|
MS_LOG(ERROR) << "malloc tmp_data failed."; |
|
|
return RET_ERROR; |
|
|
return RET_ERROR; |
|
|
} |
|
|
} |
|
|
auto trans_out_data = reinterpret_cast<float16_t *>(malloc(input_unit_ * input_unit_ * sizeof(float16_t))); |
|
|
|
|
|
|
|
|
auto trans_out_data = |
|
|
|
|
|
reinterpret_cast<float16_t *>(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float16_t))); |
|
|
if (trans_out_data == nullptr) { |
|
|
if (trans_out_data == nullptr) { |
|
|
free(tmp_data); |
|
|
free(tmp_data); |
|
|
free(tmp_weight_data); |
|
|
|
|
|
free(matrix_g_data_fp16); |
|
|
|
|
|
free(matrix_gt_data_fp16); |
|
|
free(matrix_gt_data_fp16); |
|
|
MS_LOG(ERROR) << "malloc trans_out_data failed."; |
|
|
MS_LOG(ERROR) << "malloc trans_out_data failed."; |
|
|
return RET_ERROR; |
|
|
return RET_ERROR; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (oc_block == 0) { |
|
|
|
|
|
MS_LOG(ERROR) << "Divide by zero"; |
|
|
|
|
|
free(tmp_weight_data); |
|
|
|
|
|
|
|
|
#ifndef ENABLE_ARM64 |
|
|
|
|
|
auto tmp_data1 = reinterpret_cast<float16_t *>(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float16_t))); |
|
|
|
|
|
if (tmp_data1 == nullptr) { |
|
|
free(tmp_data); |
|
|
free(tmp_data); |
|
|
|
|
|
free(matrix_gt_data_fp16); |
|
|
free(trans_out_data); |
|
|
free(trans_out_data); |
|
|
free(matrix_g_data_fp16); |
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "malloc tmp_data1 failed."; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
auto trans_out_data1 = |
|
|
|
|
|
reinterpret_cast<float16_t *>(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float16_t))); |
|
|
|
|
|
if (trans_out_data1 == nullptr) { |
|
|
|
|
|
free(tmp_data); |
|
|
|
|
|
free(tmp_data1); |
|
|
free(matrix_gt_data_fp16); |
|
|
free(matrix_gt_data_fp16); |
|
|
|
|
|
free(trans_out_data); |
|
|
|
|
|
MS_LOG(ERROR) << "malloc trans_out_data1 failed."; |
|
|
return RET_ERROR; |
|
|
return RET_ERROR; |
|
|
} |
|
|
} |
|
|
int stride1 = channel_in * oc_block; |
|
|
|
|
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
|
|
int input_oz_offset = kernel_unit_ * kernel_unit_ * channel_in; |
|
|
for (int i = 0; i < channel_out; i++) { |
|
|
for (int i = 0; i < channel_out; i++) { |
|
|
int out_c_block = i / oc_block; |
|
|
int out_c_block = i / oc_block; |
|
|
int out_c_res = i % oc_block; |
|
|
int out_c_res = i % oc_block; |
|
|
int input_oz_offset = i * kernel_unit_ * kernel_unit_ * channel_in; |
|
|
|
|
|
int output_oz_offset = out_c_block * stride1 + out_c_res; |
|
|
|
|
|
for (int j = 0; j < channel_in; j++) { |
|
|
|
|
|
int input_iz_offset = input_oz_offset + j; |
|
|
|
|
|
int output_iz_offset = output_oz_offset + j * oc_block; |
|
|
|
|
|
for (int k = 0; k < kernel_unit_ * kernel_unit_; k++) { |
|
|
|
|
|
int input_xy_offset = input_iz_offset + k * channel_in; |
|
|
|
|
|
tmp_weight_data[k] = *(weight_data + input_xy_offset); |
|
|
|
|
|
} |
|
|
|
|
|
// now we only support row-major matrix-multiply |
|
|
|
|
|
// tmp = G * g |
|
|
|
|
|
MatrixMultiplyFp16(matrix_g_data_fp16, tmp_weight_data, tmp_data, input_unit_, kernel_unit_, kernel_unit_); |
|
|
|
|
|
// out = tmp * GT |
|
|
|
|
|
MatrixMultiplyFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit_, kernel_unit_, input_unit_); |
|
|
|
|
|
|
|
|
|
|
|
for (int z = 0; z < input_unit_square; z++) { |
|
|
|
|
|
int output_xy_offset = output_iz_offset + z * oc_block_num * stride1; |
|
|
|
|
|
trans_weight_[output_xy_offset] = trans_out_data[z]; |
|
|
|
|
|
|
|
|
int output_oz_offset = out_c_block * block_stride + out_c_res; |
|
|
|
|
|
|
|
|
|
|
|
#ifndef ENABLE_ARM64 |
|
|
|
|
|
// tmp_data = g * GT |
|
|
|
|
|
MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit_, |
|
|
|
|
|
kernel_unit_, input_unit_, channel_in); |
|
|
|
|
|
// tmp_data1 = (tmp_data)T |
|
|
|
|
|
PackHWCToWHCFp16(tmp_data, tmp_data1, kernel_unit_, input_unit_, channel_in); |
|
|
|
|
|
// trans_out_data1 = tmp * GT |
|
|
|
|
|
MatrixMultiplyWinogradFp16(tmp_data1, matrix_gt_data_fp16, trans_out_data1, input_unit_, kernel_unit_, input_unit_, |
|
|
|
|
|
channel_in); |
|
|
|
|
|
// trans_out_data = (trans_out_data1)T |
|
|
|
|
|
PackHWCToWHCFp16(trans_out_data1, trans_out_data, input_unit_, input_unit_, channel_in); |
|
|
|
|
|
#else |
|
|
|
|
|
// tmp = (g * GT)T |
|
|
|
|
|
MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit_, |
|
|
|
|
|
kernel_unit_, input_unit_, channel_in); |
|
|
|
|
|
// trans = (tmp * GT)T |
|
|
|
|
|
MatrixMultiplyWinogradFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit_, kernel_unit_, input_unit_, |
|
|
|
|
|
channel_in); |
|
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
|
|
int in_offset = 0; |
|
|
|
|
|
for (int j = 0; j < input_unit_; ++j) { |
|
|
|
|
|
for (int k = 0; k < input_unit_; ++k) { |
|
|
|
|
|
for (int c = 0; c < channel_in; ++c) { |
|
|
|
|
|
*(trans_weight_ + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c]; |
|
|
|
|
|
} |
|
|
|
|
|
in_offset += channel_in; |
|
|
|
|
|
output_oz_offset += block_num_stride; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
free(tmp_weight_data); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifndef ENABLE_ARM64 |
|
|
|
|
|
free(tmp_data1); |
|
|
|
|
|
free(trans_out_data1); |
|
|
|
|
|
#endif |
|
|
free(tmp_data); |
|
|
free(tmp_data); |
|
|
free(trans_out_data); |
|
|
free(trans_out_data); |
|
|
free(matrix_g_data_fp16); |
|
|
|
|
|
free(matrix_gt_data_fp16); |
|
|
free(matrix_gt_data_fp16); |
|
|
return RET_OK; |
|
|
return RET_OK; |
|
|
} |
|
|
} |
|
|
|