|
|
@@ -56,7 +56,7 @@ void MatmulFp32BaseCPUKernel::InitParameter() { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void MatmulFp32BaseCPUKernel::ResizeParameter() { |
|
|
void MatmulFp32BaseCPUKernel::ResizeParameter() { |
|
|
if (params_->row_ == 1 && params_->b_const_ == false) { |
|
|
|
|
|
|
|
|
if (params_->row_ == 1) { |
|
|
vec_matmul_ = true; |
|
|
vec_matmul_ = true; |
|
|
} |
|
|
} |
|
|
params_->row_align_ = vec_matmul_ ? 1 : UP_ROUND(params_->row_, row_tile_); |
|
|
params_->row_align_ = vec_matmul_ ? 1 : UP_ROUND(params_->row_, row_tile_); |
|
|
@@ -238,10 +238,15 @@ int MatmulFp32BaseCPUKernel::Init() { |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (params_->b_const_ == true) { |
|
|
if (params_->b_const_ == true) { |
|
|
if (RET_OK != InitBufferB()) { |
|
|
|
|
|
|
|
|
/* copy the original b data here; packing is deferred to resize, |
|
|
|
|
|
* so that packing happens after infershape is done */ |
|
|
|
|
|
auto b_tensor = in_tensors_[1]; |
|
|
|
|
|
src_b_ = reinterpret_cast<float *>(malloc(params_->batch * params_->col_ * params_->deep_ * sizeof(float))); |
|
|
|
|
|
if (src_b_ == nullptr) { |
|
|
|
|
|
MS_LOG(ERROR) << "Matmul fp16 malloc src_b_ failed"; |
|
|
return RET_ERROR; |
|
|
return RET_ERROR; |
|
|
} |
|
|
} |
|
|
InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->data_c())); |
|
|
|
|
|
|
|
|
memcpy(src_b_, b_tensor->data_c(), params_->batch * params_->col_ * params_->deep_ * sizeof(float)); |
|
|
} |
|
|
} |
|
|
return RET_OK; |
|
|
return RET_OK; |
|
|
} |
|
|
} |
|
|
@@ -249,6 +254,15 @@ int MatmulFp32BaseCPUKernel::Init() { |
|
|
int MatmulFp32BaseCPUKernel::ReSize() { |
|
|
int MatmulFp32BaseCPUKernel::ReSize() { |
|
|
ResizeParameter(); |
|
|
ResizeParameter(); |
|
|
|
|
|
|
|
|
|
|
|
if (params_->b_const_ == true && src_b_ != nullptr) { |
|
|
|
|
|
if (RET_OK != InitBufferB()) { |
|
|
|
|
|
return RET_ERROR; |
|
|
|
|
|
} |
|
|
|
|
|
InitMatrixB(src_b_); |
|
|
|
|
|
free(src_b_); |
|
|
|
|
|
src_b_ = nullptr; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_align_, col_tile_)); |
|
|
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_align_, col_tile_)); |
|
|
thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_), thread_count_); |
|
|
thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_), thread_count_); |
|
|
return RET_OK; |
|
|
return RET_OK; |
|
|
|