Browse Source

Remove redundant PreProcess override in the arithmetic operators and fix some fp16 bugs

tags/v1.1.0
zengxianglong 5 years ago
parent
commit
49517fa756
7 changed files with 26 additions and 79 deletions
  1. +17
    -35
      mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc
  2. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h
  3. +1
    -35
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc
  4. +0
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h
  5. +1
    -1
      mindspore/lite/test/models_onnx.cfg
  6. +1
    -1
      mindspore/lite/test/models_onnx_fp16.cfg
  7. +5
    -5
      mindspore/lite/tools/benchmark/benchmark.cc

+ 17
- 35
mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc View File

@@ -21,8 +21,8 @@
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "src/runtime/runtime_api.h"
-#include "src/ops/populate/arithmetic_populate.h"
 #include "include/errorcode.h"
+#include "src/ops/arithmetic.h"

 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -100,44 +100,26 @@ int ArithmeticFP16CPUKernel::Init() {
   return ReSize();
 }

-int ArithmeticFP16CPUKernel::PreProcess() {
-  if (!InferShapeDone()) {
-    (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->set_infer_flag(true);
-    auto ret = (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_);
-    if (ret != 0) {
-      (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->set_infer_flag(false);
-      MS_LOG(ERROR) << "InferShape fail!";
-      return ret;
-    }
-    if (op_parameter_ != nullptr) {
-      free(op_parameter_);
-      op_parameter_ = nullptr;
-    }
-    op_parameter_ = PopulateArithmetic(primitive_);
-    if (op_parameter_ == nullptr) {
-      MS_LOG(ERROR) << "Malloc parameter failed";
-      return RET_ERROR;
-    }
-    param_ = reinterpret_cast<ArithmeticParameter *>(op_parameter_);
-    ret = ReSize();
-    if (ret != 0) {
-      MS_LOG(ERROR) << "ReSize fail!ret: " << ret;
-      return ret;
-    }
-  }
+void ArithmeticFP16CPUKernel::InitParam() {
+  auto arithmetic_lite_primitive = (lite::Arithmetic *)primitive_;
+  param_->broadcasting_ = arithmetic_lite_primitive->Broadcasting();
+  param_->ndim_ = arithmetic_lite_primitive->NDims();

-  auto outputs = this->out_tensors();
-  for (auto *output : outputs) {
-    MS_ASSERT(output != nullptr);
-    output->MallocData();
-  }
-  return RET_OK;
+  param_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
+  param_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
+  param_->out_elements_num_ = out_tensors_[0]->ElementsNum();
+  memcpy(param_->in_shape0_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().data(),
+         reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().size() * sizeof(int));
+  memcpy(param_->in_shape1_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().data(),
+         reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().size() * sizeof(int));
+  memcpy(param_->out_shape_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().data(),
+         reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().size() * sizeof(int));
+
+  return;
 }

 int ArithmeticFP16CPUKernel::ReSize() {
-  param_->in_elements_num0_ = in_tensors_.at(0)->ElementsNum();
-  param_->in_elements_num1_ = in_tensors_.at(1)->ElementsNum();
-  param_->out_elements_num_ = out_tensors_.at(0)->ElementsNum();
+  InitParam();

   if (param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) {
     param_->broadcasting_ = false;


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h View File

@@ -44,7 +44,6 @@ class ArithmeticFP16CPUKernel : public LiteKernel {
   ~ArithmeticFP16CPUKernel() = default;

   int Init() override;
-  int PreProcess() override;
   int ReSize() override;
   int Run() override;
   int DoArithmetic(int task_id);
@@ -52,6 +51,7 @@ class ArithmeticFP16CPUKernel : public LiteKernel {
                   int out_thread_stride);

  private:
+  void InitParam();
   void FreeTmpBuffer();
   int outside_;
   int break_pos_;


+ 1
- 35
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc View File

@@ -20,7 +20,7 @@
 #include "src/kernel_registry.h"
 #include "src/runtime/kernel/arm/int8/add_int8.h"
 #include "src/runtime/runtime_api.h"
-#include "src/ops/populate/arithmetic_populate.h"
+#include "src/ops/arithmetic.h"

 using mindspore::kernel::KERNEL_ARCH::kCPU;
 using mindspore::lite::KernelRegistrar;
@@ -94,40 +94,6 @@ int ArithmeticCPUKernel::InitBroadCastCase() {
   return RET_OK;
 }

-int ArithmeticCPUKernel::PreProcess() {
-  if (!InferShapeDone()) {
-    (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->set_infer_flag(true);
-    auto ret = (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->InferShape(in_tensors_, out_tensors_);
-    if (ret != 0) {
-      (const_cast<mindspore::lite::PrimitiveC *>(primitive_))->set_infer_flag(false);
-      MS_LOG(ERROR) << "InferShape fail!";
-      return ret;
-    }
-    if (op_parameter_ != nullptr) {
-      free(op_parameter_);
-      op_parameter_ = nullptr;
-    }
-    op_parameter_ = PopulateArithmetic(primitive_);
-    if (op_parameter_ == nullptr) {
-      MS_LOG(ERROR) << "Malloc parameter failed";
-      return RET_ERROR;
-    }
-    arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(op_parameter_);
-    ret = ReSize();
-    if (ret != 0) {
-      MS_LOG(ERROR) << "ReSize fail!ret: " << ret;
-      return ret;
-    }
-  }
-
-  auto outputs = this->out_tensors();
-  for (auto *output : outputs) {
-    MS_ASSERT(output != nullptr);
-    output->MallocData();
-  }
-  return RET_OK;
-}
-
 void ArithmeticCPUKernel::InitRunFunction() {
   switch (op_parameter_->type_) {
     case PrimitiveType_Mul:


+ 0
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h View File

@@ -62,7 +62,6 @@ class ArithmeticCPUKernel : public LiteKernel {
   ~ArithmeticCPUKernel() override;

   int Init() override;
-  int PreProcess() override;
   int ReSize() override;
   int Run() override;
   virtual int DoArithmetic(int task_id);


+ 1
- 1
mindspore/lite/test/models_onnx.cfg View File

@@ -7,7 +7,7 @@ efficientnet-lite4-11.onnx
 mobilenetv2-7.onnx
 shufflenet-v2-10.onnx
 squeezenet1.1-7.onnx
-#densenet-9.onnx
+densenet-9.onnx
 ml_table_detection_fp32.onnx
 ml_table_segment.onnx
 googlenet-9.onnx


+ 1
- 1
mindspore/lite/test/models_onnx_fp16.cfg View File

@@ -7,7 +7,7 @@ efficientnet-lite4-11.onnx 2
 mobilenetv2-7.onnx 8
 shufflenet-v2-10.onnx 5
 squeezenet1.1-7.onnx 1
-#densenet-9.onnx 6
+densenet-9.onnx 6
 ml_table_detection_fp32.onnx 2
 ml_table_segment.onnx 2
 googlenet-9.onnx 3


+ 5
- 5
mindspore/lite/tools/benchmark/benchmark.cc View File

@@ -38,20 +38,20 @@ int Benchmark::GenerateRandomData(size_t size, void *data, TypeId data_type) {
   switch (data_type) {
     case kNumberTypeFloat32:
     case kNumberTypeFloat:
-      FillInputData<float>(size, data, std::uniform_real_distribution<float>(-0.5f, 0.5f));
+      FillInputData<float>(size, data, std::uniform_real_distribution<float>(0.1f, 1.0f));
       break;
     case kNumberTypeFloat64:
-      FillInputData<double>(size, data, std::uniform_real_distribution<double>(-0.5, 0.5));
+      FillInputData<double>(size, data, std::uniform_real_distribution<double>(0.1, 1.0));
       break;
     case kNumberTypeInt64:
-      FillInputData<int64_t>(size, data, std::uniform_int_distribution<int64_t>(0, 99));
+      FillInputData<int64_t>(size, data, std::uniform_int_distribution<int64_t>(0, 1));
       break;
     case kNumberTypeInt:
    case kNumberTypeInt32:
-      FillInputData<int32_t>(size, data, std::uniform_int_distribution<int32_t>(0, 99));
+      FillInputData<int32_t>(size, data, std::uniform_int_distribution<int32_t>(0, 1));
       break;
     case kNumberTypeInt16:
-      FillInputData<int16_t>(size, data, std::uniform_int_distribution<int16_t>(0, 99));
+      FillInputData<int16_t>(size, data, std::uniform_int_distribution<int16_t>(0, 1));
       break;
     case kNumberTypeInt8:
       FillInputData<int8_t>(size, data, std::uniform_int_distribution<int8_t>(-127, 127));


Loading…
Cancel
Save