|
|
@@ -16,6 +16,7 @@ |
|
|
|
|
|
|
|
|
#include "src/runtime/kernel/arm/fp32/convolution_winograd.h" |
|
|
#include "src/runtime/kernel/arm/fp32/convolution_winograd.h" |
|
|
#include "nnacl/fp32/conv.h" |
|
|
#include "nnacl/fp32/conv.h" |
|
|
|
|
|
#include "nnacl/pack.h" |
|
|
#include "schema/model_generated.h" |
|
|
#include "schema/model_generated.h" |
|
|
#include "src/kernel_registry.h" |
|
|
#include "src/kernel_registry.h" |
|
|
#include "include/errorcode.h" |
|
|
#include "include/errorcode.h" |
|
|
@@ -31,78 +32,93 @@ using mindspore::schema::PrimitiveType_Conv2D; |
|
|
namespace mindspore::kernel { |
|
|
namespace mindspore::kernel { |
|
|
int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_data, float *matrix_g, float *matrix_gt, |
|
|
int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_data, float *matrix_g, float *matrix_gt, |
|
|
int oc_block) { |
|
|
int oc_block) { |
|
|
|
|
|
if (oc_block == 0) { |
|
|
|
|
|
MS_LOG(ERROR) << "Divide by zero"; |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
// original weight format : ohwi |
|
|
// original weight format : ohwi |
|
|
auto channel_in = conv_param_->input_channel_; |
|
|
auto channel_in = conv_param_->input_channel_; |
|
|
auto channel_out = conv_param_->output_channel_; |
|
|
auto channel_out = conv_param_->output_channel_; |
|
|
int input_unit_square = input_unit_ * input_unit_; |
|
|
|
|
|
int ic4 = UP_DIV(channel_in, C4NUM); |
|
|
int ic4 = UP_DIV(channel_in, C4NUM); |
|
|
int oc_block_num = UP_DIV(channel_out, oc_block); |
|
|
int oc_block_num = UP_DIV(channel_out, oc_block); |
|
|
|
|
|
int c4_channel = ic4 * C4NUM; |
|
|
|
|
|
int block_stride = c4_channel * oc_block; |
|
|
|
|
|
int block_num_stride = block_stride * oc_block_num; |
|
|
|
|
|
|
|
|
// trans_filter = G*g*GT (g represents weight_data) |
|
|
// trans_filter = G*g*GT (g represents weight_data) |
|
|
// separate into two steps ===> tmp = G*g ===> out = tmp * GT |
|
|
|
|
|
auto tmp_weight_data = reinterpret_cast<float *>(malloc(kernel_unit_ * kernel_unit_ * sizeof(float))); |
|
|
|
|
|
if (tmp_weight_data == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "malloc tmp_weight_data failed."; |
|
|
|
|
|
return RET_MEMORY_FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
auto tmp_data = reinterpret_cast<float *>(malloc(input_unit_ * kernel_unit_ * sizeof(float))); |
|
|
|
|
|
|
|
|
// separate into two steps ===> tmp = (g * GT)T ===> trans = (tmp * GT)T use same function:MatrixMultiplyWinograd |
|
|
|
|
|
auto tmp_data = reinterpret_cast<float *>(malloc(c4_channel * input_unit_ * kernel_unit_ * sizeof(float))); |
|
|
if (tmp_data == nullptr) { |
|
|
if (tmp_data == nullptr) { |
|
|
free(tmp_weight_data); |
|
|
|
|
|
MS_LOG(ERROR) << "malloc tmp_data failed."; |
|
|
MS_LOG(ERROR) << "malloc tmp_data failed."; |
|
|
return RET_MEMORY_FAILED; |
|
|
return RET_MEMORY_FAILED; |
|
|
} |
|
|
} |
|
|
auto trans_out_data = reinterpret_cast<float *>(malloc(input_unit_ * input_unit_ * sizeof(float))); |
|
|
|
|
|
|
|
|
memset(tmp_data, 0, c4_channel * input_unit_ * kernel_unit_ * sizeof(float)); |
|
|
|
|
|
auto trans_out_data = reinterpret_cast<float *>(malloc(c4_channel * input_unit_ * input_unit_ * sizeof(float))); |
|
|
if (trans_out_data == nullptr) { |
|
|
if (trans_out_data == nullptr) { |
|
|
free(tmp_data); |
|
|
free(tmp_data); |
|
|
free(tmp_weight_data); |
|
|
|
|
|
MS_LOG(ERROR) << "malloc trans_out_data failed."; |
|
|
MS_LOG(ERROR) << "malloc trans_out_data failed."; |
|
|
return RET_MEMORY_FAILED; |
|
|
return RET_MEMORY_FAILED; |
|
|
} |
|
|
} |
|
|
std::vector<int> shape{input_unit_ * input_unit_, oc_block_num, ic4, C4NUM, oc_block}; |
|
|
|
|
|
std::vector<int> strides; |
|
|
|
|
|
for (int i = 0; i < 4; i++) { |
|
|
|
|
|
int stride = 1; |
|
|
|
|
|
for (int j = i + 1; j < 5; j++) { |
|
|
|
|
|
stride *= shape[j]; |
|
|
|
|
|
} |
|
|
|
|
|
strides.push_back(stride); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
int kernel_plane_stride = channel_in; |
|
|
|
|
|
if (oc_block == 0) { |
|
|
|
|
|
MS_LOG(ERROR) << "Divide by zero"; |
|
|
|
|
|
free(tmp_weight_data); |
|
|
|
|
|
|
|
|
#ifndef ENABLE_ARM64 |
|
|
|
|
|
auto tmp_data1 = reinterpret_cast<float *>(malloc(c4_channel * input_unit_ * kernel_unit_ * sizeof(float))); |
|
|
|
|
|
if (tmp_data1 == nullptr) { |
|
|
free(tmp_data); |
|
|
free(tmp_data); |
|
|
free(trans_out_data); |
|
|
free(trans_out_data); |
|
|
return RET_ERROR; |
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "malloc tmp_data1 failed."; |
|
|
|
|
|
return RET_MEMORY_FAILED; |
|
|
} |
|
|
} |
|
|
|
|
|
auto trans_out_data1 = reinterpret_cast<float *>(malloc(c4_channel * input_unit_ * input_unit_ * sizeof(float))); |
|
|
|
|
|
if (trans_out_data1 == nullptr) { |
|
|
|
|
|
free(tmp_data); |
|
|
|
|
|
free(tmp_data1); |
|
|
|
|
|
free(trans_out_data); |
|
|
|
|
|
MS_LOG(ERROR) << "malloc trans_out_data1 failed."; |
|
|
|
|
|
return RET_MEMORY_FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
|
|
int input_oz_offset = kernel_unit_ * kernel_unit_ * channel_in; |
|
|
for (int i = 0; i < channel_out; i++) { |
|
|
for (int i = 0; i < channel_out; i++) { |
|
|
int out_c_block = i / oc_block; |
|
|
int out_c_block = i / oc_block; |
|
|
int out_c_res = i % oc_block; |
|
|
int out_c_res = i % oc_block; |
|
|
int input_oz_offset = i * kernel_unit_ * kernel_unit_ * channel_in; |
|
|
|
|
|
int output_oz_offset = out_c_block * strides[1] + out_c_res; |
|
|
|
|
|
for (int j = 0; j < channel_in; j++) { |
|
|
|
|
|
int ic4_block = j / C4NUM; |
|
|
|
|
|
int ic4_res = j % C4NUM; |
|
|
|
|
|
int input_iz_offset = input_oz_offset + j; |
|
|
|
|
|
int output_iz_offset = output_oz_offset + ic4_block * strides[2] + ic4_res * strides[3]; |
|
|
|
|
|
for (int k = 0; k < kernel_unit_ * kernel_unit_; k++) { |
|
|
|
|
|
int input_xy_offset = input_iz_offset + k * kernel_plane_stride; |
|
|
|
|
|
tmp_weight_data[k] = *(weight_data + input_xy_offset); |
|
|
|
|
|
} |
|
|
|
|
|
// now we only support row-major matrix-multiply |
|
|
|
|
|
// tmp = G * g |
|
|
|
|
|
MatrixMultiply(matrix_g, tmp_weight_data, tmp_data, input_unit_, kernel_unit_, kernel_unit_); |
|
|
|
|
|
// out = tmp * GT |
|
|
|
|
|
MatrixMultiply(tmp_data, matrix_gt, trans_out_data, input_unit_, kernel_unit_, input_unit_); |
|
|
|
|
|
|
|
|
int output_oz_offset = out_c_block * block_stride + out_c_res; |
|
|
|
|
|
|
|
|
for (int z = 0; z < input_unit_square; z++) { |
|
|
|
|
|
int output_xy_offset = output_iz_offset + z * strides[0]; |
|
|
|
|
|
*(trans_weight_ + output_xy_offset) = trans_out_data[z]; |
|
|
|
|
|
|
|
|
#ifndef ENABLE_ARM64 |
|
|
|
|
|
// tmp_data = g * GT |
|
|
|
|
|
MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit_, kernel_unit_, |
|
|
|
|
|
input_unit_, channel_in, c4_channel * 4); |
|
|
|
|
|
// tmp_data1 = (tmp_data)T |
|
|
|
|
|
PackHWCToWHC(tmp_data, tmp_data1, kernel_unit_, input_unit_, c4_channel); |
|
|
|
|
|
// trans_out_data1 = tmp * GT |
|
|
|
|
|
MatrixMultiplyWinograd(tmp_data1, matrix_gt, trans_out_data1, input_unit_, kernel_unit_, input_unit_, c4_channel, |
|
|
|
|
|
c4_channel * 4); |
|
|
|
|
|
// trans_out_data = (trans_out_data1)T |
|
|
|
|
|
PackHWCToWHC(trans_out_data1, trans_out_data, input_unit_, input_unit_, c4_channel); |
|
|
|
|
|
#else |
|
|
|
|
|
// tmp = (g * GT)T |
|
|
|
|
|
MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit_, kernel_unit_, |
|
|
|
|
|
input_unit_, channel_in, c4_channel * 4); |
|
|
|
|
|
// trans = (tmp * GT)T |
|
|
|
|
|
MatrixMultiplyWinograd(tmp_data, matrix_gt, trans_out_data, input_unit_, kernel_unit_, input_unit_, c4_channel, |
|
|
|
|
|
c4_channel * 4); |
|
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
|
|
int in_offset = 0; |
|
|
|
|
|
for (int j = 0; j < input_unit_; ++j) { |
|
|
|
|
|
for (int k = 0; k < input_unit_; ++k) { |
|
|
|
|
|
for (int c = 0; c < c4_channel; ++c) { |
|
|
|
|
|
*(trans_weight_ + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c]; |
|
|
|
|
|
} |
|
|
|
|
|
in_offset += c4_channel; |
|
|
|
|
|
output_oz_offset += block_num_stride; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
free(tmp_weight_data); |
|
|
|
|
|
|
|
|
#ifndef ENABLE_ARM64 |
|
|
|
|
|
free(tmp_data1); |
|
|
|
|
|
free(trans_out_data1); |
|
|
|
|
|
#endif |
|
|
free(tmp_data); |
|
|
free(tmp_data); |
|
|
free(trans_out_data); |
|
|
free(trans_out_data); |
|
|
return RET_OK; |
|
|
return RET_OK; |
|
|
|