@@ -17,8 +17,8 @@
 #include "nnacl/fp16/winograd_utils_fp16.h"
 #include "nnacl/fp16/matrix_fp16.h"
 
-#define MIN_UNIT 2
-#define MAX_UNIT 8
+#define MIN_UNIT_FP16 2
+#define MAX_UNIT_FP16 4
 
 void GeneralInputTransformUnitFp16(const float16_t *src_data, float16_t *dst_data, float16_t *matrix_b,
                                    float16_t *matrix_bt, int src_step, int dst_step, int in_unit) {
@@ -2942,3 +2942,52 @@ void OutputTransform8x7Relu6UnitFp16(const float16_t *src_data, float16_t *dst_d
     }
   }
 }
+
+int SelectOutputUnitFp16(ConvParameter *conv_param) {
+  int kernel_h = conv_param->kernel_h_;
+  int kernel_w = conv_param->kernel_w_;
+  int in_c = conv_param->input_channel_;
+  int out_w = conv_param->output_w_;
+  int out_h = conv_param->output_h_;
+  int out_c = conv_param->output_channel_;
+  int unit2 = UP_DIV(out_w * out_h, C16NUM * conv_param->op_parameter_.thread_num_);
+  int max_out_unit = (int)(sqrtf((float)unit2));
+  max_out_unit = max_out_unit < MAX_UNIT_FP16 ? max_out_unit : MAX_UNIT_FP16;
+  max_out_unit = max_out_unit > MIN_UNIT_FP16 ? max_out_unit : MIN_UNIT_FP16;
+
+  int unit = 0;
+  float max_rate = 0.0f;
+  float common_cost = (float)out_h * out_w * in_c * out_c * kernel_h * kernel_w;
+
+  for (int i = MIN_UNIT_FP16; i <= max_out_unit; ++i) {
+    int input_unit = i + kernel_w - 1;
+    if (!GetOutputTransFp16Func(input_unit, i, ActType_No)) {
+      continue;
+    }
+    float penalty = ((float)input_unit * input_unit) / ((float)kernel_h * kernel_w) * 0.12f;
+    float wino_cost = ((2 + out_c) * (float)input_unit * input_unit * in_c + ((float)input_unit + i) * i * out_c) *
+                      UP_DIV(out_w, i) * UP_DIV(out_h, i);
+    float reduce_rate = common_cost / wino_cost - penalty;
+    if (reduce_rate > max_rate) {
+      max_rate = reduce_rate;
+      unit = i;
+    }
+  }
+  if (max_rate < 1.0f) {
+    return 1;
+  }
+  // If output_unit is 1, then it is conventional convolution
+  return unit;
+}
+
+void CheckIfUseWinogradFp16(bool *use_winograd, int *output_unit, ConvParameter *conv_param) {
+  if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 &&
+      conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) {
+    *output_unit = SelectOutputUnitFp16(conv_param);
+    if (*output_unit > 1) {
+      *use_winograd = true;
+    }
+  } else {
+    *use_winograd = false;
+  }
+}
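
For reference, here is a minimal sketch of how the two new helpers might be exercised. It is not part of the patch: the include paths and the zero-initialized ConvParameter are assumptions for illustration, and only the fields the selection logic actually reads (kernel, stride, dilation, channel/output sizes, thread count, all named in the diff above) are filled in.

#include <stdbool.h>
#include <stdio.h>
#include "nnacl/conv_parameter.h"            /* assumed header providing ConvParameter */
#include "nnacl/fp16/winograd_utils_fp16.h"  /* assumed to declare CheckIfUseWinogradFp16 */

int main(void) {
  ConvParameter conv_param = {0};  /* illustrative: only the fields read by the selector are set */
  conv_param.kernel_h_ = 3;
  conv_param.kernel_w_ = 3;
  conv_param.dilation_h_ = 1;
  conv_param.dilation_w_ = 1;
  conv_param.stride_h_ = 1;
  conv_param.stride_w_ = 1;
  conv_param.input_channel_ = 32;
  conv_param.output_channel_ = 64;
  conv_param.output_h_ = 56;
  conv_param.output_w_ = 56;
  conv_param.op_parameter_.thread_num_ = 4;

  bool use_winograd = false;  /* CheckIfUseWinogradFp16 only sets this to true */
  int output_unit = 1;
  CheckIfUseWinogradFp16(&use_winograd, &output_unit, &conv_param);
  /* For a square stride-1, dilation-1 kernel, SelectOutputUnitFp16 compares the direct
     cost out_h*out_w*in_c*out_c*k_h*k_w against the Winograd cost for each candidate
     output unit in [MIN_UNIT_FP16, MAX_UNIT_FP16] and keeps the best ratio; it returns
     1 (conventional convolution) when no ratio exceeds 1.0. */
  printf("use_winograd=%d, output_unit=%d\n", use_winograd, output_unit);
  return 0;
}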