diff --git a/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.c b/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.c index 5adedddd18..caa048bf8b 100644 --- a/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.c +++ b/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.c @@ -91,6 +91,14 @@ int ElementLogicalNot(const float *input, float *output, const int element_size) return NNACL_OK; } +// logical_not: +int ElementLogicalNotBool(const bool *input, bool *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = !input[i]; + } + return NNACL_OK; +} + // round: int ElementRound(const float *input, float *output, const int element_size) { for (int i = 0; i < element_size; i++) { diff --git a/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.h b/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.h index 6aa92f3581..d29aaa5a16 100644 --- a/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.h +++ b/mindspore/lite/nnacl/fp32/arithmetic_self_fp32.h @@ -42,6 +42,8 @@ int ElementSin(const float *input, float *output, const int element_size); int ElementLogicalNot(const float *input, float *output, const int element_size); +int ElementLogicalNotBool(const bool *input, bool *output, const int element_size); + int ElementRound(const float *input, float *output, const int element_size); int ElementFloor(const float *input, float *output, const int element_size); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc index 62df9fabf1..840d76c5a8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_compare_fp32.cc @@ -80,6 +80,9 @@ int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) { stride = UP_DIV(outside_, thread_count_); int out_count = MSMIN(stride, outside_ - stride * task_id); int out_thread_stride = stride * task_id; + if (out_count <= 0) { + return RET_OK; + } if (data_type_ == kDataTypeFloat) { error_code = BroadcastRun( reinterpret_cast(in_tensors_[0]->data_c()), reinterpret_cast(in_tensors_[1]->data_c()), diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc index b802488a0e..f0af0c1e98 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc @@ -82,27 +82,12 @@ int ArithmeticCPUKernel::ReSize() { arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); - for (size_t i = 0; i < in_tensors_[0]->shape().size(); i++) { - if (arithmeticParameter_->in_shape0_[i] == -1) { - memcpy(arithmeticParameter_->in_shape0_, static_cast(in_tensors_[0]->shape().data()), - in_tensors_[0]->shape().size() * sizeof(int)); - break; - } - } - for (size_t i = 0; i < in_tensors_[1]->shape().size(); i++) { - if (arithmeticParameter_->in_shape1_[i] == -1) { - memcpy(arithmeticParameter_->in_shape1_, static_cast(in_tensors_[1]->shape().data()), - in_tensors_[1]->shape().size() * sizeof(int)); - break; - } - } - for (size_t i = 0; i < out_tensors_[0]->shape().size(); i++) { - if (arithmeticParameter_->out_shape_[i] == -1) { - memcpy(arithmeticParameter_->out_shape_, static_cast(out_tensors_[0]->shape().data()), - out_tensors_[0]->shape().size() * sizeof(int)); - break; - } - } + memcpy(arithmeticParameter_->in_shape0_, reinterpret_cast(primitive_)->InShape0().data(), + reinterpret_cast(primitive_)->InShape0().size() * sizeof(int)); + memcpy(arithmeticParameter_->in_shape1_, reinterpret_cast(primitive_)->InShape1().data(), + reinterpret_cast(primitive_)->InShape1().size() * sizeof(int)); + memcpy(arithmeticParameter_->out_shape_, reinterpret_cast(primitive_)->OutputShape().data(), + reinterpret_cast(primitive_)->OutputShape().size() * sizeof(int)); if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) { switch (arithmeticParameter_->op_parameter_.type_) { @@ -244,6 +229,9 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) { if (arithmeticParameter_->broadcasting_) { // need broadcast stride = UP_DIV(outside_, thread_count_); int out_count = MSMIN(stride, outside_ - stride * task_id); + if (out_count <= 0) { + return RET_OK; + } int out_thread_stride = stride * task_id; if (data_type_ == kDataTypeFloat) { error_code = BroadcastRun(reinterpret_cast(in_tensors_[0]->data_c()), diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.cc index 7482874ff0..a9d978a4a8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.cc @@ -50,6 +50,13 @@ ArithmeticSelfFunc ArithmeticSelfCPUKernel::GetArithmeticSelfFun(int primitive_t return nullptr; } +ArithmeticSelfBoolFunc ArithmeticSelfCPUKernel::GetArithmeticSelfBoolFun(int primitive_type) { + if (primitive_type == mindspore::schema::PrimitiveType_LogicalNot) { + return ElementLogicalNotBool; + } + return nullptr; +} + int ArithmeticSelfCPUKernel::Init() { if (!InferShapeDone()) { return RET_OK; @@ -67,13 +74,27 @@ int ArithmeticSelfCPUKernel::DoExecute(int task_id) { if (count <= 0) { return RET_OK; } - if (func_ == nullptr) { - MS_LOG(ERROR) << "Run function is null! "; + int ret = RET_ERROR; + if (in_tensors_[0]->data_type() == kNumberTypeFloat32) { + if (func_ == nullptr) { + MS_LOG(ERROR) << "Run function is null! "; + return RET_ERROR; + } + float *input_ptr = reinterpret_cast(in_tensors_.at(0)->data_c()); + float *output_ptr = reinterpret_cast(out_tensors_.at(0)->data_c()); + ret = func_(input_ptr + offset, output_ptr + offset, count); + } else if (in_tensors_[0]->data_type() == kNumberTypeBool) { + if (func_bool_ == nullptr) { + MS_LOG(ERROR) << "Run function is null! "; + return RET_ERROR; + } + bool *input_ptr = reinterpret_cast(in_tensors_.at(0)->data_c()); + bool *output_ptr = reinterpret_cast(out_tensors_.at(0)->data_c()); + ret = func_bool_(input_ptr + offset, output_ptr + offset, count); + } else { + MS_LOG(ERROR) << "Unsupported type: " << in_tensors_[0]->data_type() << "."; return RET_ERROR; } - float *input_ptr = reinterpret_cast(in_tensors_.at(0)->MutableData()); - float *output_ptr = reinterpret_cast(out_tensors_.at(0)->MutableData()); - auto ret = func_(input_ptr + offset, output_ptr + offset, count); if (ret != RET_OK) { MS_LOG(ERROR) << "Run failed, illegal input! "; } @@ -126,6 +147,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sqrt, CpuArithmeticSelfFp32Ke REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Rsqrt, CpuArithmeticSelfFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sin, CpuArithmeticSelfFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator) +REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Floor, CpuArithmeticSelfFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Ceil, CpuArithmeticSelfFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Round, CpuArithmeticSelfFp32KernelCreator) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.h index 0da2a09c91..5d6a775653 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_self_fp32.h @@ -34,6 +34,7 @@ using mindspore::schema::PrimitiveType_Square; namespace mindspore::kernel { typedef int (*ArithmeticSelfFunc)(const float *input, float *output, const int element_size); +typedef int (*ArithmeticSelfBoolFunc)(const bool *input, bool *output, const int element_size); class ArithmeticSelfCPUKernel : public LiteKernel { public: explicit ArithmeticSelfCPUKernel(OpParameter *parameter, const std::vector &inputs, @@ -41,6 +42,7 @@ class ArithmeticSelfCPUKernel : public LiteKernel { const mindspore::lite::PrimitiveC *primitive) : LiteKernel(parameter, inputs, outputs, ctx, primitive) { func_ = GetArithmeticSelfFun(parameter->type_); + func_bool_ = GetArithmeticSelfBoolFun(parameter->type_); } ~ArithmeticSelfCPUKernel() override = default; @@ -51,7 +53,9 @@ class ArithmeticSelfCPUKernel : public LiteKernel { private: ArithmeticSelfFunc GetArithmeticSelfFun(int primitive_type); + ArithmeticSelfBoolFunc GetArithmeticSelfBoolFun(int primitive_type); ArithmeticSelfFunc func_; + ArithmeticSelfBoolFunc func_bool_; }; int ArithmeticSelfRun(void *cdata, int task_id); } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc index 94e9f4f9c5..bdfc349ec4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/arithmetic_self_int8.cc @@ -146,5 +146,4 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8Kerne REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator) REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator) -REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator) } // namespace mindspore::kernel