|
|
|
@@ -78,7 +78,15 @@ int FullconnectionFP16CPUKernel::ReSize() { |
|
|
|
} |
|
|
|
memset(b_pack_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(float16_t)); |
|
|
|
|
|
|
|
InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->data_c()), b_pack_ptr_); |
|
|
|
fc_param_->b_const_ = (in_tensors_[1]->data_c() != nullptr); |
|
|
|
if (fc_param_->b_const_) { |
|
|
|
if (in_tensors_[1]->data_type() == kNumberTypeFloat32) { |
|
|
|
InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->data_c()), b_pack_ptr_); |
|
|
|
} else { |
|
|
|
InitMatrixB(reinterpret_cast<float16_t *>(in_tensors_[1]->data_c()), b_pack_ptr_); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (in_tensors_.size() == 3) { |
|
|
|
bias_ptr_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(fc_param_->col_8_ * sizeof(float16_t))); |
|
|
|
if (bias_ptr_ == nullptr) { |
|
|
|
@@ -108,6 +116,10 @@ void FullconnectionFP16CPUKernel::InitMatrixB(float *b_ptr, float16_t *b_pack_pt |
|
|
|
RowMajor2Col8MajorFp16(reinterpret_cast<void *>(b_ptr), b_pack_ptr, fc_param_->col_, fc_param_->deep_, true); |
|
|
|
} |
|
|
|
|
|
|
|
// Packs weight matrix B, already in fp16, from row-major (col_ x deep_) into the
// 8-column-packed layout consumed by the fp16 matmul kernel.
// NOTE(review): the trailing 'false' presumably means "source is not fp32, no
// conversion needed" — the fp32 overload passes 'true'; confirm against
// RowMajor2Col8MajorFp16's declaration.
void FullconnectionFP16CPUKernel::InitMatrixB(float16_t *b_ptr, float16_t *b_pack_ptr) { |
|
|
|
RowMajor2Col8MajorFp16(reinterpret_cast<void *>(b_ptr), b_pack_ptr, fc_param_->col_, fc_param_->deep_, false); |
|
|
|
} |
|
|
|
|
|
|
|
int FullconnectionFP16CPUKernel::Init() { |
|
|
|
if (!InferShapeDone()) { |
|
|
|
return RET_OK; |
|
|
|
@@ -156,6 +168,13 @@ int FullconnectionFP16CPUKernel::Run() { |
|
|
|
} else { |
|
|
|
InitMatrixA(reinterpret_cast<float16_t *>(in_tensors_[0]->data_c()), a_pack_ptr_); |
|
|
|
} |
|
|
|
if (!fc_param_->b_const_) { |
|
|
|
if (in_tensors_[1]->data_type() == kNumberTypeFloat32) { |
|
|
|
InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->data_c()), b_pack_ptr_); |
|
|
|
} else { |
|
|
|
InitMatrixB(reinterpret_cast<float16_t *>(in_tensors_[1]->data_c()), b_pack_ptr_); |
|
|
|
} |
|
|
|
} |
|
|
|
ParallelLaunch(this->context_->thread_pool_, FcFP16Run, this, thread_count_); |
|
|
|
if (out_tensor->data_type() == kNumberTypeFloat32) { |
|
|
|
auto size = out_tensor->ElementsNum(); |
|
|
|
|