From: @gongdaguo Reviewed-by: @zhang_xue_tong,@zhang_xue_tong,@zhanghaibo5 Signed-off-by: @zhang_xue_tong tags/v1.1.0
| @@ -91,6 +91,14 @@ int ElementLogicalNot(const float *input, float *output, const int element_size) | |||
| return NNACL_OK; | |||
| } | |||
| // logical_not (bool): stores the negation of each of the first element_size entries of input into output. | |||
| // Always returns NNACL_OK; when element_size is not positive the loop body never runs and output is untouched. | |||
| int ElementLogicalNotBool(const bool *input, bool *output, const int element_size) { | |||
| int idx = 0; | |||
| while (idx < element_size) { | |||
| output[idx] = input[idx] ? false : true; | |||
| ++idx; | |||
| } | |||
| return NNACL_OK; | |||
| } | |||
| // round: | |||
| int ElementRound(const float *input, float *output, const int element_size) { | |||
| for (int i = 0; i < element_size; i++) { | |||
| @@ -42,6 +42,8 @@ int ElementSin(const float *input, float *output, const int element_size); | |||
| int ElementLogicalNot(const float *input, float *output, const int element_size); | |||
| int ElementLogicalNotBool(const bool *input, bool *output, const int element_size); | |||
| int ElementRound(const float *input, float *output, const int element_size); | |||
| int ElementFloor(const float *input, float *output, const int element_size); | |||
| @@ -80,6 +80,9 @@ int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) { | |||
| stride = UP_DIV(outside_, thread_count_); | |||
| int out_count = MSMIN(stride, outside_ - stride * task_id); | |||
| int out_thread_stride = stride * task_id; | |||
| if (out_count <= 0) { | |||
| return RET_OK; | |||
| } | |||
| if (data_type_ == kDataTypeFloat) { | |||
| error_code = BroadcastRun( | |||
| reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(in_tensors_[1]->data_c()), | |||
| @@ -82,27 +82,12 @@ int ArithmeticCPUKernel::ReSize() { | |||
| arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); | |||
| arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); | |||
| arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); | |||
| for (size_t i = 0; i < in_tensors_[0]->shape().size(); i++) { | |||
| if (arithmeticParameter_->in_shape0_[i] == -1) { | |||
| memcpy(arithmeticParameter_->in_shape0_, static_cast<void *>(in_tensors_[0]->shape().data()), | |||
| in_tensors_[0]->shape().size() * sizeof(int)); | |||
| break; | |||
| } | |||
| } | |||
| for (size_t i = 0; i < in_tensors_[1]->shape().size(); i++) { | |||
| if (arithmeticParameter_->in_shape1_[i] == -1) { | |||
| memcpy(arithmeticParameter_->in_shape1_, static_cast<void *>(in_tensors_[1]->shape().data()), | |||
| in_tensors_[1]->shape().size() * sizeof(int)); | |||
| break; | |||
| } | |||
| } | |||
| for (size_t i = 0; i < out_tensors_[0]->shape().size(); i++) { | |||
| if (arithmeticParameter_->out_shape_[i] == -1) { | |||
| memcpy(arithmeticParameter_->out_shape_, static_cast<void *>(out_tensors_[0]->shape().data()), | |||
| out_tensors_[0]->shape().size() * sizeof(int)); | |||
| break; | |||
| } | |||
| } | |||
| memcpy(arithmeticParameter_->in_shape0_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().data(), | |||
| reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape0().size() * sizeof(int)); | |||
| memcpy(arithmeticParameter_->in_shape1_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().data(), | |||
| reinterpret_cast<const lite::Arithmetic *>(primitive_)->InShape1().size() * sizeof(int)); | |||
| memcpy(arithmeticParameter_->out_shape_, reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().data(), | |||
| reinterpret_cast<const lite::Arithmetic *>(primitive_)->OutputShape().size() * sizeof(int)); | |||
| if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) { | |||
| switch (arithmeticParameter_->op_parameter_.type_) { | |||
| @@ -244,6 +229,9 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) { | |||
| if (arithmeticParameter_->broadcasting_) { // need broadcast | |||
| stride = UP_DIV(outside_, thread_count_); | |||
| int out_count = MSMIN(stride, outside_ - stride * task_id); | |||
| if (out_count <= 0) { | |||
| return RET_OK; | |||
| } | |||
| int out_thread_stride = stride * task_id; | |||
| if (data_type_ == kDataTypeFloat) { | |||
| error_code = BroadcastRun(reinterpret_cast<float *>(in_tensors_[0]->data_c()), | |||
| @@ -50,6 +50,13 @@ ArithmeticSelfFunc ArithmeticSelfCPUKernel::GetArithmeticSelfFun(int primitive_t | |||
| return nullptr; | |||
| } | |||
| // Selects the bool-typed element-wise routine for primitive_type. | |||
| // Only PrimitiveType_LogicalNot has one (ElementLogicalNotBool); any other type yields nullptr. | |||
| ArithmeticSelfBoolFunc ArithmeticSelfCPUKernel::GetArithmeticSelfBoolFun(int primitive_type) { | |||
| ArithmeticSelfBoolFunc selected = nullptr; | |||
| if (primitive_type == mindspore::schema::PrimitiveType_LogicalNot) { | |||
| selected = ElementLogicalNotBool; | |||
| } | |||
| return selected; | |||
| } | |||
| int ArithmeticSelfCPUKernel::Init() { | |||
| if (!InferShapeDone()) { | |||
| return RET_OK; | |||
| @@ -67,13 +74,27 @@ int ArithmeticSelfCPUKernel::DoExecute(int task_id) { | |||
| if (count <= 0) { | |||
| return RET_OK; | |||
| } | |||
| if (func_ == nullptr) { | |||
| MS_LOG(ERROR) << "Run function is null! "; | |||
| int ret = RET_ERROR; | |||
| if (in_tensors_[0]->data_type() == kNumberTypeFloat32) { | |||
| if (func_ == nullptr) { | |||
| MS_LOG(ERROR) << "Run function is null! "; | |||
| return RET_ERROR; | |||
| } | |||
| float *input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->data_c()); | |||
| float *output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->data_c()); | |||
| ret = func_(input_ptr + offset, output_ptr + offset, count); | |||
| } else if (in_tensors_[0]->data_type() == kNumberTypeBool) { | |||
| if (func_bool_ == nullptr) { | |||
| MS_LOG(ERROR) << "Run function is null! "; | |||
| return RET_ERROR; | |||
| } | |||
| bool *input_ptr = reinterpret_cast<bool *>(in_tensors_.at(0)->data_c()); | |||
| bool *output_ptr = reinterpret_cast<bool *>(out_tensors_.at(0)->data_c()); | |||
| ret = func_bool_(input_ptr + offset, output_ptr + offset, count); | |||
| } else { | |||
| MS_LOG(ERROR) << "Unsupported type: " << in_tensors_[0]->data_type() << "."; | |||
| return RET_ERROR; | |||
| } | |||
| float *input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||
| float *output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); | |||
| auto ret = func_(input_ptr + offset, output_ptr + offset, count); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Run failed, illegal input! "; | |||
| } | |||
| @@ -126,6 +147,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sqrt, CpuArithmeticSelfFp32Ke | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Rsqrt, CpuArithmeticSelfFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sin, CpuArithmeticSelfFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Floor, CpuArithmeticSelfFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Ceil, CpuArithmeticSelfFp32KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Round, CpuArithmeticSelfFp32KernelCreator) | |||
| @@ -34,6 +34,7 @@ using mindspore::schema::PrimitiveType_Square; | |||
| namespace mindspore::kernel { | |||
| typedef int (*ArithmeticSelfFunc)(const float *input, float *output, const int element_size); | |||
| typedef int (*ArithmeticSelfBoolFunc)(const bool *input, bool *output, const int element_size); | |||
| class ArithmeticSelfCPUKernel : public LiteKernel { | |||
| public: | |||
| explicit ArithmeticSelfCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| @@ -41,6 +42,7 @@ class ArithmeticSelfCPUKernel : public LiteKernel { | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : LiteKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| func_ = GetArithmeticSelfFun(parameter->type_); | |||
| func_bool_ = GetArithmeticSelfBoolFun(parameter->type_); | |||
| } | |||
| ~ArithmeticSelfCPUKernel() override = default; | |||
| @@ -51,7 +53,9 @@ class ArithmeticSelfCPUKernel : public LiteKernel { | |||
| private: | |||
| ArithmeticSelfFunc GetArithmeticSelfFun(int primitive_type); | |||
| ArithmeticSelfBoolFunc GetArithmeticSelfBoolFun(int primitive_type); | |||
| ArithmeticSelfFunc func_; | |||
| ArithmeticSelfBoolFunc func_bool_; | |||
| }; | |||
| int ArithmeticSelfRun(void *cdata, int task_id); | |||
| } // namespace mindspore::kernel | |||
| @@ -146,5 +146,4 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8Kerne | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator) | |||
| } // namespace mindspore::kernel | |||