|
|
|
@@ -50,6 +50,13 @@ ArithmeticSelfFunc ArithmeticSelfCPUKernel::GetArithmeticSelfFun(int primitive_t |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
ArithmeticSelfBoolFunc ArithmeticSelfCPUKernel::GetArithmeticSelfBoolFun(int primitive_type) { |
|
|
|
if (primitive_type == mindspore::schema::PrimitiveType_LogicalNot) { |
|
|
|
return ElementLogicalNotBool; |
|
|
|
} |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
int ArithmeticSelfCPUKernel::Init() { |
|
|
|
if (!InferShapeDone()) { |
|
|
|
return RET_OK; |
|
|
|
@@ -67,13 +74,27 @@ int ArithmeticSelfCPUKernel::DoExecute(int task_id) { |
|
|
|
if (count <= 0) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
if (func_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Run function is null! "; |
|
|
|
int ret = RET_ERROR; |
|
|
|
if (in_tensors_[0]->data_type() == kNumberTypeFloat32) { |
|
|
|
if (func_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Run function is null! "; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
float *input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->data_c()); |
|
|
|
float *output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->data_c()); |
|
|
|
ret = func_(input_ptr + offset, output_ptr + offset, count); |
|
|
|
} else if (in_tensors_[0]->data_type() == kNumberTypeBool) { |
|
|
|
if (func_bool_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "Run function is null! "; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
bool *input_ptr = reinterpret_cast<bool *>(in_tensors_.at(0)->data_c()); |
|
|
|
bool *output_ptr = reinterpret_cast<bool *>(out_tensors_.at(0)->data_c()); |
|
|
|
ret = func_bool_(input_ptr + offset, output_ptr + offset, count); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Unsupported type: " << in_tensors_[0]->data_type() << "."; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
float *input_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); |
|
|
|
float *output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->MutableData()); |
|
|
|
auto ret = func_(input_ptr + offset, output_ptr + offset, count); |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Run failed, illegal input! "; |
|
|
|
} |
|
|
|
@@ -126,6 +147,7 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sqrt, CpuArithmeticSelfFp32Ke |
|
|
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Rsqrt, CpuArithmeticSelfFp32KernelCreator) |
|
|
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sin, CpuArithmeticSelfFp32KernelCreator) |
|
|
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator) |
|
|
|
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfFp32KernelCreator) |
|
|
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Floor, CpuArithmeticSelfFp32KernelCreator) |
|
|
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Ceil, CpuArithmeticSelfFp32KernelCreator) |
|
|
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Round, CpuArithmeticSelfFp32KernelCreator) |
|
|
|
|