Browse Source

[MSLITE] Optimize gather fp16 op's performance

tags/v1.2.0-rc1
zhanyuan 4 years ago
parent
commit
69426cadf6
2 changed files with 25 additions and 7 deletions
  1. +23
    -6
      mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc
  2. +2
    -1
      mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.h

+ 23
- 6
mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc View File

@@ -30,7 +30,22 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Gather;

namespace mindspore::kernel {
// Destructor: hands the fp16 input buffer back to the context allocator if one
// is still held. Init() allocates such a buffer when the first input tensor is
// float32 and already carries data, and Run() leaves it alone in that case
// (const_input_ == true), so it is still set when the kernel is destroyed.
GatherFp16CPUKernel::~GatherFp16CPUKernel() {
  // Nothing was allocated, or Run() already released it.
  if (input_data_ == nullptr) {
    return;
  }
  context_->allocator->Free(input_data_);
  // Clear the pointer so no stale address remains in the member.
  input_data_ = nullptr;
}

int GatherFp16CPUKernel::Init() {
auto input_tensor = in_tensors_.at(0);
if (input_tensor->data_type() == kNumberTypeFloat32 && input_tensor->data_c() != nullptr) {
const_input_ = true;
input_data_ =
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
}

if (!InferShapeDone()) {
return RET_OK;
}
@@ -128,11 +143,13 @@ int GatherFp16CPUKernel::Run() {
MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
return ret;
}
auto input_tensor = in_tensors_.at(0);
if (input_tensor->data_type() == kNumberTypeFloat32) {
input_data_ =
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
if (!const_input_) {
auto input_tensor = in_tensors_.at(0);
if (input_tensor->data_type() == kNumberTypeFloat32) {
input_data_ =
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
}
}
ret = ParallelLaunch(this->context_->thread_pool_, GatherRunFp16, this, op_parameter_->thread_num_);
if (ret != RET_OK) {
@@ -142,7 +159,7 @@ int GatherFp16CPUKernel::Run() {
context_->allocator->Free(indices_data_);
indices_data_ = nullptr;
}
if (input_data_) {
if (!const_input_ && input_data_) {
context_->allocator->Free(input_data_);
input_data_ = nullptr;
}


+ 2
- 1
mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.h View File

@@ -30,7 +30,7 @@ class GatherFp16CPUKernel : public LiteKernel {
GatherFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
: LiteKernel(parameter, inputs, outputs, ctx) {}
~GatherFp16CPUKernel() = default;
~GatherFp16CPUKernel() override;

int Init() override;
int ReSize() override;
@@ -42,6 +42,7 @@ class GatherFp16CPUKernel : public LiteKernel {
int *indices_data_ = nullptr;
int AssignIndicesData(bool isIndicesInt32, int indices_num, lite::Tensor *indices_tensor);
float16_t *input_data_ = nullptr;
bool const_input_ = false;
};
} // namespace mindspore::kernel



Loading…
Cancel
Save