
[MS][LITE] fix bug of arm cpu fp16 infer: set subgraph output tensor data_type float32

yangruoqi713, 5 years ago
commit 4e4ad85cb1
tags/v1.0.0
3 changed files with 27 additions and 47 deletions
  1. +20 -35  mindspore/lite/src/runtime/kernel/arm/fp16/pooling_fp16.cc
  2. +1 -9    mindspore/lite/src/runtime/kernel/arm/fp16/pooling_fp16.h
  3. +6 -3    mindspore/lite/src/scheduler.cc

mindspore/lite/src/runtime/kernel/arm/fp16/pooling_fp16.cc  (+20 -35)

@@ -21,6 +21,7 @@
 #include "include/errorcode.h"
 #include "nnacl/op_base.h"
 #include "nnacl/fp16/cast_fp16.h"
+#include "src/runtime/kernel/arm/fp16/common_fp16.h"

 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -29,29 +30,6 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Pooling;

 namespace mindspore::kernel {
-int PoolingFp16CPUKernel::InitBuffer() {
-  int in_batch = pooling_param_->input_batch_;
-  int in_h = pooling_param_->input_h_;
-  int in_w = pooling_param_->input_w_;
-  int in_channel = pooling_param_->input_channel_;
-  fp16_input_ = reinterpret_cast<float16_t *>(malloc(in_batch * in_h * in_w * in_channel * sizeof(float16_t)));
-  if (fp16_input_ == nullptr) {
-    MS_LOG(ERROR) << "malloc fp16_input_ failed.";
-    return RET_ERROR;
-  }
-
-  int out_batch = pooling_param_->output_batch_;
-  int out_h = pooling_param_->output_h_;
-  int out_w = pooling_param_->output_w_;
-  int out_channel = pooling_param_->output_channel_;
-  fp16_output_ = reinterpret_cast<float16_t *>(malloc(out_batch * out_h * out_w * out_channel * sizeof(float16_t)));
-  if (fp16_output_ == nullptr) {
-    MS_LOG(ERROR) << "fp16_out malloc failed.";
-    return RET_ERROR;
-  }
-  return RET_OK;
-}
-
 int PoolingFp16CPUKernel::Init() {
   auto ret = PoolingBaseCPUKernel::Init();
   if (ret != RET_OK) {
@@ -71,12 +49,6 @@ int PoolingFp16CPUKernel::ReSize() {
     MS_LOG(ERROR) << "PoolingBase ReSize fail!ret: " << ret;
     return ret;
   }
-
-  ret = InitBuffer();
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Init Buffer fail!ret: " << ret;
-    return ret;
-  }
   return RET_OK;
 }


@@ -105,9 +77,16 @@ int PoolingFp16CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
     return prepare_ret;
   }
-  auto ele_num = in_tensors_.front()->ElementsNum();
-  auto input_ptr = reinterpret_cast<float *>(in_tensors_.at(kInputIndex)->Data());
-  Float32ToFloat16(input_ptr, fp16_input_, ele_num);
+
+  auto input_tensor = in_tensors_.at(kInputIndex);
+  auto in_data_type_ = input_tensor->data_type();
+  MS_ASSERT(in_data_type_ == kNumberTypeFloat32 || in_data_type_ == kNumberTypeFloat16);
+  fp16_input_ = ConvertInputFp32toFp16(input_tensor, context_);
+
+  auto out_tensor = out_tensors_.at(kOutputIndex);
+  auto out_data_type_ = out_tensor->data_type();
+  MS_ASSERT(out_data_type_ == kNumberTypeFloat32 || out_data_type_ == kNumberTypeFloat16);
+  fp16_output_ = MallocOutputFp16(out_tensor, context_);

   int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, PoolingFp16Impl, this, thread_count_);
   if (error_code != RET_OK) {
@@ -115,9 +94,15 @@ int PoolingFp16CPUKernel::Run() {
     return RET_ERROR;
   }

-  auto out_ele_num = out_tensors_.front()->ElementsNum();
-  auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->Data());
-  Float16ToFloat32(fp16_output_, output_ptr, out_ele_num);
+  if (in_data_type_ == kNumberTypeFloat32) {
+    context_->allocator->Free(fp16_input_);
+  }
+  if (out_data_type_ == kNumberTypeFloat32) {
+    auto out_ele_num = out_tensor->ElementsNum();
+    auto output_addr = reinterpret_cast<float *>(out_tensor->Data());
+    Float16ToFloat32(fp16_output_, output_addr, out_ele_num);
+    context_->allocator->Free(fp16_output_);
+  }
   return RET_OK;
 }
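
The rewritten Run() leans on two helpers declared in src/runtime/kernel/arm/fp16/common_fp16.h: ConvertInputFp32toFp16() and MallocOutputFp16(). Their bodies are not part of this diff; the sketch below is inferred from how Run() uses and frees the returned buffers (allocate from the context allocator only when the tensor is kNumberTypeFloat32, otherwise reuse the tensor's own fp16 data) and may differ from the real common_fp16.cc. The same inference explains the header change that follows: the kernel no longer owns fp16_input_ and fp16_output_ via malloc/free, so the custom destructor and InitBuffer() go away.

// Sketch only, not the actual MindSpore Lite implementation.
#include "src/runtime/kernel/arm/fp16/common_fp16.h"
#include "nnacl/fp16/cast_fp16.h"

namespace mindspore::kernel {
float16_t *ConvertInputFp32toFp16(lite::tensor::Tensor *input, const lite::Context *ctx) {
  if (input->data_type() == kNumberTypeFloat16) {
    // Input already produced in fp16 by an upstream fp16 kernel: reuse it in place.
    return reinterpret_cast<float16_t *>(input->Data());
  }
  auto ele_num = input->ElementsNum();
  auto *fp16_data = reinterpret_cast<float16_t *>(ctx->allocator->Malloc(ele_num * sizeof(float16_t)));
  if (fp16_data == nullptr) {
    MS_LOG(ERROR) << "malloc fp16 input buffer failed.";
    return nullptr;
  }
  // fp32 tensor coming from outside the fp16 subgraph: convert once on entry.
  Float32ToFloat16(reinterpret_cast<float *>(input->Data()), fp16_data, ele_num);
  return fp16_data;  // Run() frees this because the tensor is kNumberTypeFloat32.
}

float16_t *MallocOutputFp16(lite::tensor::Tensor *output, const lite::Context *ctx) {
  if (output->data_type() == kNumberTypeFloat16) {
    // Intermediate tensor inside the fp16 subgraph: the kernel writes fp16 directly.
    return reinterpret_cast<float16_t *>(output->Data());
  }
  // Subgraph output tagged kNumberTypeFloat32: compute into a scratch fp16 buffer;
  // Run() converts it back to float and frees it afterwards.
  auto ele_num = output->ElementsNum();
  return reinterpret_cast<float16_t *>(ctx->allocator->Malloc(ele_num * sizeof(float16_t)));
}
}  // namespace mindspore::kernel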




mindspore/lite/src/runtime/kernel/arm/fp16/pooling_fp16.h  (+1 -9)

@@ -28,17 +28,9 @@ class PoolingFp16CPUKernel : public PoolingBaseCPUKernel {
                       const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx,
                       const mindspore::lite::PrimitiveC *primitive)
       : PoolingBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
-  ~PoolingFp16CPUKernel() override {
-    if (fp16_input_ != nullptr) {
-      free(fp16_input_);
-    }
-    if (fp16_output_ != nullptr) {
-      free(fp16_output_);
-    }
-  };
+  ~PoolingFp16CPUKernel() override = default;

   int Init() override;
-  int InitBuffer();
   int ReSize() override;
   int Run() override;
   int RunImpl(int task_id);


mindspore/lite/src/scheduler.cc  (+6 -3)

@@ -182,9 +182,12 @@ void Scheduler::ConstructSubgraphs(std::vector<kernel::LiteKernel *> *kernels) {
     for (auto kernel : temp_kernels) {
       for (auto tensor : kernel->out_tensors()) {
         tensor->set_allocator(context_->allocator.get());
-        if (context_->float16_priority && tensor->data_type() == kNumberTypeFloat16) {
-          tensor->set_data_type(kNumberTypeFloat32);
-        }
       }
     }
+    std::vector<tensor::Tensor *> output_tensor = kernel::LiteKernelUtil::SubgraphOutputTensors(temp_kernels);
+    for (auto tensor : output_tensor) {
+      if (context_->float16_priority && tensor->data_type() == kNumberTypeFloat16) {
+        tensor->set_data_type(kNumberTypeFloat32);
+      }
+    }
     std::copy(temp_kernels.begin(), temp_kernels.end(), std::back_inserter(subgraph_kernels));
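
Net effect of the scheduler change together with the kernel change above: tensors that stay inside an fp16 subgraph keep kNumberTypeFloat16, so fp16 kernels can hand half-precision data to each other without extra casts, while only tensors that leave the subgraph are re-tagged kNumberTypeFloat32 and converted back by the producing kernel in Run(). A minimal stand-alone illustration of that boundary contract (toy buffers instead of lite tensors, a trivial 1-D max pool as the fp16 kernel; assumes an aarch64 toolchain where float16_t comes from <arm_neon.h>, none of it MindSpore Lite API):

// Toy illustration only; names and data are made up for this example.
#include <arm_neon.h>
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // fp32 data as the caller provides it to the subgraph.
  std::vector<float> input = {1.5f, -2.25f, 3.0f, 0.125f, 7.0f, 4.5f};
  const size_t n = input.size();

  // Subgraph entry: a single fp32 -> fp16 cast (what ConvertInputFp32toFp16 does).
  std::vector<float16_t> fp16_in(n);
  for (size_t i = 0; i < n; ++i) fp16_in[i] = static_cast<float16_t>(input[i]);

  // Intermediate fp16 kernel: 1-D max pooling, window 2, stride 2, entirely in fp16.
  std::vector<float16_t> fp16_out(n / 2);
  for (size_t i = 0; i < fp16_out.size(); ++i) {
    fp16_out[i] = std::max(fp16_in[2 * i], fp16_in[2 * i + 1]);
  }

  // Subgraph output tensor is tagged kNumberTypeFloat32, so the producing kernel
  // converts back to float exactly once before the consumer reads it.
  std::vector<float> output(fp16_out.size());
  for (size_t i = 0; i < output.size(); ++i) output[i] = static_cast<float>(fp16_out[i]);

  for (float v : output) std::printf("%g ", v);  // prints: 1.5 3 7
  std::printf("\n");
  return 0;
}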

