|
|
|
@@ -257,6 +257,9 @@ void MatrixMultiplyVec(const float32x4_t *matrix_a, const float32x4_t *matrix_b, |
|
|
|
|
|
|
|
int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, const float *matrix_gt, |
|
|
|
int oc_block, int input_unit, int kernel_unit, int channel, int batch, bool pack) { |
|
|
|
if (oc_block == 0) { |
|
|
|
return NNACL_PARAM_INVALID; |
|
|
|
} |
|
|
|
// original weight format : ohwi |
|
|
|
int oc_block_num = UP_DIV(batch, oc_block); |
|
|
|
int block_stride = channel * oc_block; |
|
|
|
|