|
|
|
@@ -27,8 +27,8 @@ constexpr float MUL1_y = 0.5; |
|
|
|
|
|
|
|
// gelu(x) = 1/2 * x * [1 + erf(x / sqrt(2))] |
|
|
|
const BaseRef OnnxGeLUFusion::DefinePattern() const { |
|
|
|
VectorRef div_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimDivFusion>), input_, div_y_}); |
|
|
|
VectorRef erf_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimErf>), div_ref}); |
|
|
|
VectorRef div_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>), input_, div_y_}); |
|
|
|
VectorRef erf_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimErf>), div_ref}); |
|
|
|
VectorRef add_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), erf_ref, add_y_}); |
|
|
|
VectorRef mul1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), input_, mul1_y_}); |
|
|
|
VectorRef mul2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul1_ref, add_ref}); |
|
|
|
|