Browse Source

!9268 [MS][LITE] Support LogicalNot bool operation, fix arithmetic

From: @gongdaguo
Reviewed-by: @zhang_xue_tong,@zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
8453b0d243
7 changed files with 53 additions and 27 deletions
  1. +8
    -0
      mindspore/lite/nnacl/fp32/arithmetic_self_fp32.c
  2. +2
    -0
      mindspore/lite/nnacl/fp32/arithmetic_self_fp32.h
  3. +3
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc
  4. +9
    -21
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc
  5. +27
    -5
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.cc
  6. +4
    -0
      mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.h
  7. +0
    -1
      mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc

+ 8
- 0
mindspore/lite/nnacl/fp32/arithmetic_self_fp32.c View File

@@ -91,6 +91,14 @@ int ElementLogicalNot(const float *input, float *output, const int element_size)
return NNACL_OK; return NNACL_OK;
} }


// logical_not: element-wise boolean negation.
// Writes !input[i] into output[i] for each of the element_size entries.
int ElementLogicalNotBool(const bool *input, bool *output, const int element_size) {
  int idx = 0;
  while (idx < element_size) {
    output[idx] = !input[idx];
    ++idx;
  }
  return NNACL_OK;
}

// round: // round:
int ElementRound(const float *input, float *output, const int element_size) { int ElementRound(const float *input, float *output, const int element_size) {
for (int i = 0; i < element_size; i++) { for (int i = 0; i < element_size; i++) {


+ 2
- 0
mindspore/lite/nnacl/fp32/arithmetic_self_fp32.h View File

@@ -42,6 +42,8 @@ int ElementSin(const float *input, float *output, const int element_size);


int ElementLogicalNot(const float *input, float *output, const int element_size); int ElementLogicalNot(const float *input, float *output, const int element_size);


int ElementLogicalNotBool(const bool *input, bool *output, const int element_size);

int ElementRound(const float *input, float *output, const int element_size); int ElementRound(const float *input, float *output, const int element_size);


int ElementFloor(const float *input, float *output, const int element_size); int ElementFloor(const float *input, float *output, const int element_size);


+ 3
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc View File

@@ -80,6 +80,9 @@ int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
stride = UP_DIV(outside_, thread_count_); stride = UP_DIV(outside_, thread_count_);
int out_count = MSMIN(stride, outside_ - stride * task_id); int out_count = MSMIN(stride, outside_ - stride * task_id);
int out_thread_stride = stride * task_id; int out_thread_stride = stride * task_id;
if (out_count <= 0) {
return RET_OK;
}
if (data_type_ == kDataTypeFloat) { if (data_type_ == kDataTypeFloat) {
error_code = BroadcastRun( error_code = BroadcastRun(
reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(in_tensors_[1]->data_c()), reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(in_tensors_[1]->data_c()),


+ 9
- 21
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc View File

@@ -82,27 +82,12 @@ int ArithmeticCPUKernel::ReSize() {
arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
for (size_t i = 0; i < in_tensors_[0]->shape().size(); i++) {
if (arithmeticParameter_->in_shape0_[i] == -1) {
memcpy(arithmeticParameter_->in_shape0_, static_cast<void *>(in_tensors_[0]->shape().data()),
in_tensors_[0]->shape().size() * sizeof(int));
break;
}
}
for (size_t i = 0; i < in_tensors_[1]->shape().size(); i++) {
if (arithmeticParameter_->in_shape1_[i] == -1) {
memcpy(arithmeticParameter_->in_shape1_, static_cast<void *>(in_tensors_[1]->shape().data()),
in_tensors_[1]->shape().size() * sizeof(int));
break;
}
}
for (size_t i = 0; i < out_tensors_[0]->shape().size(); i++) {
if (arithmeticParameter_->out_shape_[i] == -1) {
memcpy(arithmeticParameter_->out_shape_, static_cast<void *>(out_tensors_[0]->shape().data()),
out_tensors_[0]->shape().size() * sizeof(int));
break;
}
}
memcpy(arithmeticParameter_->in_shape0_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().size() * sizeof(int));
memcpy(arithmeticParameter_->in_shape1_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().size() * sizeof(int));
memcpy(arithmeticParameter_->out_shape_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().data(),
reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().size() * sizeof(int));


if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) { if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
switch (arithmeticParameter_->op_parameter_.type_) { switch (arithmeticParameter_->op_parameter_.type_) {
@@ -244,6 +229,9 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) {
if (arithmeticParameter_->broadcasting_) { // need broadcast if (arithmeticParameter_->broadcasting_) { // need broadcast
stride = UP_DIV(outside_, thread_count_); stride = UP_DIV(outside_, thread_count_);
int out_count = MSMIN(stride, outside_ - stride * task_id); int out_count = MSMIN(stride, outside_ - stride * task_id);
if (out_count <= 0) {
return RET_OK;
}
int out_thread_stride = stride * task_id; int out_thread_stride = stride * task_id;
if (data_type_ == kDataTypeFloat) { if (data_type_ == kDataTypeFloat) {
error_code = BroadcastRun(reinterpret_cast<float *>(in_tensors_[0]->data_c()), error_code = BroadcastRun(reinterpret_cast<float *>(in_tensors_[0]->data_c()),


+ 27
- 5
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.cc View File

@@ -50,6 +50,13 @@ ArithmeticSelfFunc ArithmeticSelfCPUKernel::GetArithmeticSelfFun(int primitive_t
return nullptr; return nullptr;
} }


// Looks up the bool-typed element-wise implementation for the given primitive
// type. Only LogicalNot has a bool kernel; every other op yields nullptr so
// the caller can detect the unsupported combination.
ArithmeticSelfBoolFunc ArithmeticSelfCPUKernel::GetArithmeticSelfBoolFun(int primitive_type) {
  return (primitive_type == mindspore::schema::PrimitiveType_LogicalNot) ? ElementLogicalNotBool : nullptr;
}

int ArithmeticSelfCPUKernel::Init() { int ArithmeticSelfCPUKernel::Init() {
if (!InferShapeDone()) { if (!InferShapeDone()) {
return RET_OK; return RET_OK;
@@ -67,13 +74,27 @@ int ArithmeticSelfCPUKernel::DoExecute(int task_id) {
if (count <= 0) { if (count <= 0) {
return RET_OK; return RET_OK;
} }
if (func_ == nullptr) {
MS_LOG(ERROR) << "Run function is null! ";
int ret = RET_ERROR;
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) {
if (func_ == nullptr) {
MS_LOG(ERROR) << "Run function is null! ";
return RET_ERROR;
}
float *input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
float *output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
ret = func_(input_ptr + offset, output_ptr + offset, count);
} else if (in_tensors_[0]->data_type() == kNumberTypeBool) {
if (func_bool_ == nullptr) {
MS_LOG(ERROR) << "Run function is null! ";
return RET_ERROR;
}
bool *input_ptr = reinterpret_cast<bool *>(in_tensors_.at(0)->data_c());
bool *output_ptr = reinterpret_cast<bool *>(out_tensors_.at(0)->data_c());
ret = func_bool_(input_ptr + offset, output_ptr + offset, count);
} else {
MS_LOG(ERROR) << "Unsupported type: " << in_tensors_[0]->data_type() << ".";
return RET_ERROR; return RET_ERROR;
} }
float *input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
float *output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData());
auto ret = func_(input_ptr + offset, output_ptr + offset, count);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Run failed, illegal input! "; MS_LOG(ERROR) << "Run failed, illegal input! ";
} }
@@ -126,6 +147,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sqrt, CpuArithmeticSelfFp32Ke
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Rsqrt, CpuArithmeticSelfFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Rsqrt, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sin, CpuArithmeticSelfFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sin, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Floor, CpuArithmeticSelfFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Floor, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Ceil, CpuArithmeticSelfFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Ceil, CpuArithmeticSelfFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Round, CpuArithmeticSelfFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Round, CpuArithmeticSelfFp32KernelCreator)


+ 4
- 0
mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.h View File

@@ -34,6 +34,7 @@ using mindspore::schema::PrimitiveType_Square;


namespace mindspore::kernel { namespace mindspore::kernel {
typedef int (*ArithmeticSelfFunc)(const float *input, float *output, const int element_size); typedef int (*ArithmeticSelfFunc)(const float *input, float *output, const int element_size);
typedef int (*ArithmeticSelfBoolFunc)(const bool *input, bool *output, const int element_size);
class ArithmeticSelfCPUKernel : public LiteKernel { class ArithmeticSelfCPUKernel : public LiteKernel {
public: public:
explicit ArithmeticSelfCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, explicit ArithmeticSelfCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
@@ -41,6 +42,7 @@ class ArithmeticSelfCPUKernel : public LiteKernel {
const mindspore::lite::PrimitiveC *primitive) const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive) { : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
func_ = GetArithmeticSelfFun(parameter->type_); func_ = GetArithmeticSelfFun(parameter->type_);
func_bool_ = GetArithmeticSelfBoolFun(parameter->type_);
} }
~ArithmeticSelfCPUKernel() override = default; ~ArithmeticSelfCPUKernel() override = default;


@@ -51,7 +53,9 @@ class ArithmeticSelfCPUKernel : public LiteKernel {


private: private:
ArithmeticSelfFunc GetArithmeticSelfFun(int primitive_type); ArithmeticSelfFunc GetArithmeticSelfFun(int primitive_type);
ArithmeticSelfBoolFunc GetArithmeticSelfBoolFun(int primitive_type);
ArithmeticSelfFunc func_; ArithmeticSelfFunc func_;
ArithmeticSelfBoolFunc func_bool_;
}; };
int ArithmeticSelfRun(void *cdata, int task_id); int ArithmeticSelfRun(void *cdata, int task_id);
} // namespace mindspore::kernel } // namespace mindspore::kernel


+ 0
- 1
mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc View File

@@ -146,5 +146,4 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8Kerne
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
} // namespace mindspore::kernel } // namespace mindspore::kernel

Loading…
Cancel
Save