|
|
|
@@ -29,6 +29,7 @@ using mindspore::schema::ActivationType_HSWISH; |
|
|
|
using mindspore::schema::ActivationType_LEAKY_RELU; |
|
|
|
using mindspore::schema::ActivationType_RELU; |
|
|
|
using mindspore::schema::ActivationType_RELU6; |
|
|
|
using mindspore::schema::ActivationType_SWISH; |
|
|
|
using mindspore::schema::PrimitiveType_Activation; |
|
|
|
|
|
|
|
namespace mindspore::kernel { |
|
|
|
@@ -44,32 +45,34 @@ int ActivationCPUKernel::DoActivation(int task_id) { |
|
|
|
int stride = UP_DIV(length, thread_count_); |
|
|
|
int count = MSMIN(stride, length - stride * task_id); |
|
|
|
|
|
|
|
auto error_code = RET_OK; |
|
|
|
auto ret = RET_OK; |
|
|
|
|
|
|
|
if (type_ == schema::ActivationType_RELU) { |
|
|
|
error_code = Fp32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id); |
|
|
|
ret = Fp32Relu(input_addr + stride * task_id, count, output_addr + stride * task_id); |
|
|
|
} else if (type_ == schema::ActivationType_RELU6) { |
|
|
|
error_code = Fp32Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id); |
|
|
|
ret = Fp32Relu6(input_addr + stride * task_id, count, output_addr + stride * task_id); |
|
|
|
} else if (type_ == schema::ActivationType_LEAKY_RELU) { |
|
|
|
error_code = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_); |
|
|
|
ret = LRelu(input_addr + stride * task_id, count, output_addr + stride * task_id, alpha_); |
|
|
|
} else if (type_ == schema::ActivationType_SIGMOID) { |
|
|
|
error_code = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); |
|
|
|
ret = Sigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); |
|
|
|
} else if (type_ == schema::ActivationType_TANH) { |
|
|
|
error_code = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id); |
|
|
|
ret = Tanh(input_addr + stride * task_id, count, output_addr + stride * task_id); |
|
|
|
} else if (type_ == schema::ActivationType_SWISH) { |
|
|
|
ret = Swish(input_addr + stride * task_id, count, output_addr + stride * task_id); |
|
|
|
} else if (type_ == schema::ActivationType_HSWISH) { |
|
|
|
error_code = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id); |
|
|
|
ret = HSwish(input_addr + stride * task_id, count, output_addr + stride * task_id); |
|
|
|
} else if (type_ == schema::ActivationType_HSIGMOID) { |
|
|
|
error_code = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); |
|
|
|
ret = HSigmoid(input_addr + stride * task_id, count, output_addr + stride * task_id); |
|
|
|
} else if (type_ == schema::ActivationType_HARD_TANH) { |
|
|
|
error_code = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_); |
|
|
|
ret = HardTanh(input_addr + stride * task_id, count, output_addr + stride * task_id, min_val_, max_val_); |
|
|
|
} else { |
|
|
|
MS_LOG(ERROR) << "Activation type error"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (error_code != RET_OK) { |
|
|
|
return RET_ERROR; |
|
|
|
if (ret != RET_OK) { |
|
|
|
MS_LOG(ERROR) << "Activation error, ret: " << ret; |
|
|
|
} |
|
|
|
return RET_OK; |
|
|
|
return ret; |
|
|
|
} |
|
|
|
|
|
|
|
int ActivationRun(void *cdata, int task_id) { |
|
|
|
|