From 03bf0baef15692ae9d79503c143002090817bd27 Mon Sep 17 00:00:00 2001 From: ling Date: Fri, 19 Feb 2021 16:52:18 +0800 Subject: [PATCH] [MSLITE] fp32 matmal vec_matmul --- .../kernel/arm/fp32/matmul_fp32_base.cc | 20 ++++++++++++++++--- .../kernel/arm/fp32/matmul_fp32_base.h | 1 + 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc index 93869e6cb2..fbfcd642d3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.cc @@ -56,7 +56,7 @@ void MatmulFp32BaseCPUKernel::InitParameter() { } void MatmulFp32BaseCPUKernel::ResizeParameter() { - if (params_->row_ == 1 && params_->b_const_ == false) { + if (params_->row_ == 1) { vec_matmul_ = true; } params_->row_align_ = vec_matmul_ ? 1 : UP_ROUND(params_->row_, row_tile_); @@ -238,10 +238,15 @@ int MatmulFp32BaseCPUKernel::Init() { } if (params_->b_const_ == true) { - if (RET_OK != InitBufferB()) { + /* copy origin b data, pack in resize + * pack after a infershape done */ + auto b_tensor = in_tensors_[1]; + src_b_ = reinterpret_cast(malloc(params_->batch * params_->col_ * params_->deep_ * sizeof(float))); + if (src_b_ == nullptr) { + MS_LOG(ERROR) << "Matmul fp16 malloc src_b_ failed"; return RET_ERROR; } - InitMatrixB(reinterpret_cast(in_tensors_[1]->data_c())); + memcpy(src_b_, b_tensor->data_c(), params_->batch * params_->col_ * params_->deep_ * sizeof(float)); } return RET_OK; } @@ -249,6 +254,15 @@ int MatmulFp32BaseCPUKernel::Init() { int MatmulFp32BaseCPUKernel::ReSize() { ResizeParameter(); + if (params_->b_const_ == true && src_b_ != nullptr) { + if (RET_OK != InitBufferB()) { + return RET_ERROR; + } + InitMatrixB(src_b_); + free(src_b_); + src_b_ = nullptr; + } + thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(params_->col_align_, col_tile_)); thread_stride_ = UP_DIV(UP_DIV(params_->col_align_, col_tile_), thread_count_); return RET_OK; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h index 71b44d28cf..5a8a5e558c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/matmul_fp32_base.h @@ -69,6 +69,7 @@ class MatmulFp32BaseCPUKernel : public LiteKernel { int thread_stride_ = 0; int thread_count_ = 0; bool vec_matmul_ = false; + float *src_b_ = nullptr; float *bias_ptr_ = nullptr; float *batch_a_ptr_ = nullptr; float *batch_b_ptr_ = nullptr;