|
|
|
@@ -57,7 +57,7 @@ int MatmulCPUKernel::MallocMatrixABuffer() { |
|
|
|
params_->batch = batch; |
|
|
|
params_->row_ = params_->a_transpose_ ? a_shape[a_shape.size() - 1] : a_shape[a_shape.size() - 2]; |
|
|
|
#ifdef ENABLE_ARM64 |
|
|
|
if (params_->row_ == 1) { |
|
|
|
if (params_->a_init_shape_ && params_->row_ == 1) { |
|
|
|
is_vector_a_ = true; |
|
|
|
} |
|
|
|
#endif |
|
|
|
@@ -134,7 +134,7 @@ int MatmulCPUKernel::InitBias() { |
|
|
|
} |
|
|
|
|
|
|
|
int MatmulCPUKernel::ReSize() { |
|
|
|
if (params_->a_const_ == false || params_->a_has_shape_ == false) { |
|
|
|
if (params_->a_const_ == false || params_->a_init_shape_ == false) { |
|
|
|
if (a_pack_ptr_ != nullptr) { |
|
|
|
free(a_pack_ptr_); |
|
|
|
a_pack_ptr_ = nullptr; |
|
|
|
@@ -145,7 +145,7 @@ int MatmulCPUKernel::ReSize() { |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
if (params_->b_const_ == false || params_->b_has_shape_ == false) { |
|
|
|
if (params_->b_const_ == false || params_->b_init_shape_ == false) { |
|
|
|
if (b_pack_ptr_ != nullptr) { |
|
|
|
free(b_pack_ptr_); |
|
|
|
b_pack_ptr_ = nullptr; |
|
|
|
@@ -222,16 +222,16 @@ void MatmulCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) { |
|
|
|
} |
|
|
|
|
|
|
|
int MatmulCPUKernel::Init() { |
|
|
|
params_->a_has_shape_ = (in_tensors_[0]->shape().size() != 0); |
|
|
|
params_->b_has_shape_ = (in_tensors_[1]->shape().size() != 0); |
|
|
|
if (params_->a_has_shape_ == true) { |
|
|
|
params_->a_init_shape_ = (in_tensors_[0]->shape().size() != 0); |
|
|
|
params_->b_init_shape_ = (in_tensors_[1]->shape().size() != 0); |
|
|
|
if (params_->a_init_shape_ == true) { |
|
|
|
auto ret = MallocMatrixABuffer(); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Matmul fp32 malloc matrix a buffer failed"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
if (params_->b_has_shape_ == true) { |
|
|
|
if (params_->b_init_shape_ == true) { |
|
|
|
auto ret = MallocMatrixBBuffer(); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Matmul fp32 malloc matrix b buffer failed"; |
|
|
|
@@ -300,7 +300,7 @@ int MatmulCPUKernel::Run() { |
|
|
|
} |
|
|
|
} |
|
|
|
if (params_->b_const_ == false || is_train()) { |
|
|
|
if (is_vector_a_) { |
|
|
|
if (is_vector_a_ && params_->b_transpose_) { |
|
|
|
b_ptr_ = b_src; |
|
|
|
} else { |
|
|
|
InitMatrixB(b_src, b_pack_ptr_); |
|
|
|
|