@@ -37,7 +37,7 @@ int ActivationFp16CPUKernel::Init() {
   if (type_ != schema::ActivationType_RELU && type_ != schema::ActivationType_RELU6 &&
       type_ != schema::ActivationType_LEAKY_RELU && type_ != schema::ActivationType_SIGMOID &&
       type_ != schema::ActivationType_TANH && type_ != schema::ActivationType_HSWISH &&
-      type_ != schema::ActivationType_SWISH) {
+      type_ != schema::ActivationType_SWISH && type_ != schema::ActivationType_HARD_TANH) {
     MS_LOG(ERROR) << "Activation fp16 not support type: " << type_;
     return RET_ERROR;
   }
@@ -67,6 +67,9 @@ int ActivationFp16CPUKernel::DoActivation(int task_id) {
     error_code = HSwishFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count);
   } else if (type_ == schema::ActivationType_SWISH) {
     error_code = SwishFp16(fp16_input_ + stride * task_id, fp16_output_ + stride * task_id, count);
+  } else if (type_ == schema::ActivationType_HARD_TANH) {
+    error_code =
+      HardTanhFp16(fp16_input_ + stride * task_id, count, fp16_output_ + stride * task_id, min_val_, max_val_);
   } else {
     MS_LOG(ERROR) << "Activation fp16 not support type: " << type_;
     return RET_ERROR;
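
For context, a minimal sketch of the HardTanhFp16 routine the new branch calls, assuming an ARM target where float16_t is provided by <arm_neon.h>. The real nnacl implementation is not part of this diff; the parameter order (src, length, dst, min_val, max_val) is taken from the call site above, and the return codes below are placeholders for the library's own status codes, not the actual nnacl values.

#include <arm_neon.h>

// Hard tanh: clamp each fp16 element into [min_val, max_val].
// Returns 0 on success and -1 on bad arguments (placeholder codes,
// an assumption standing in for the library's status constants).
int HardTanhFp16(const float16_t *src, int length, float16_t *dst, float min_val, float max_val) {
  if (src == nullptr || dst == nullptr || max_val <= min_val) {
    return -1;
  }
  float16_t lo = (float16_t)min_val;
  float16_t hi = (float16_t)max_val;
  for (int i = 0; i < length; ++i) {
    float16_t v = src[i];
    dst[i] = v < lo ? lo : (v > hi ? hi : v);
  }
  return 0;
}

With min_val_ = -1 and max_val_ = 1 this reduces to the standard hard tanh; the bounds presumably come from the kernel's activation parameter, as with the other parameterized activations.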