| @@ -59,4 +59,115 @@ pnnx.Output output 1 0 out | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_gelu_1, 10) | |||
| class F_gelu_2 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| // x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 11 10 | |||
| pnnx.Input input_0 0 1 input | |||
| prim::Constant op_0 0 1 12 value=%0p5 | |||
| aten::mul op_1 2 1 input 12 13 | |||
| prim::Constant op_2 0 1 15 value=%sqrt2 | |||
| aten::div op_3 2 1 input 15 16 | |||
| aten::erf op_4 1 1 16 17 | |||
| prim::Constant op_5 0 1 20 value=%1 | |||
| prim::Constant op_6 0 1 21 value=1 | |||
| aten::add op_7 3 1 17 20 21 22 | |||
| aten::mul op_8 2 1 13 22 out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| bool match(const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| if (captured_params.at("0p5").f != 0.5f) | |||
| return false; | |||
| if (fabs(captured_params.at("sqrt2").f - sqrt(2.f)) > 0.0001f) | |||
| return false; | |||
| if ((captured_params.at("1").type == 2 && captured_params.at("1").i != 1) || (captured_params.at("1").type == 3 && captured_params.at("1").f != 1.f)) | |||
| return false; | |||
| return true; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "F.gelu"; | |||
| } | |||
| void write(Operator* /*op*/, const std::map<std::string, Parameter>& /*captured_params*/) const | |||
| { | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_gelu_2, 9) | |||
| class F_gelu_3 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| // 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 17 16 | |||
| pnnx.Input input_0 0 1 input | |||
| prim::Constant op_0 0 1 60 value=%0p5 | |||
| aten::mul op_1 2 1 input 60 26 | |||
| prim::Constant op_2 0 1 28 value=%3 | |||
| aten::pow op_3 2 1 input 28 29 | |||
| prim::Constant op_4 0 1 30 value=%0p044715 | |||
| aten::mul op_5 2 1 29 30 31 | |||
| prim::Constant op_6 0 1 61 value=1 | |||
| aten::add op_7 3 1 input 31 61 35 | |||
| prim::Constant op_8 0 1 36 value=%sqrt2dpi | |||
| aten::mul op_9 2 1 35 36 37 | |||
| aten::tanh op_10 1 1 37 39 | |||
| prim::Constant op_11 0 1 62 value=%1 | |||
| prim::Constant op_12 0 1 63 value=%1_1 | |||
| aten::add op_13 3 1 39 62 63 42 | |||
| aten::mul op_14 2 1 26 42 out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| bool match(const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| if (captured_params.at("0p5").f != 0.5f) | |||
| return false; | |||
| if (fabs(captured_params.at("0p044715").f - 0.044715f) > 0.0001f) | |||
| return false; | |||
| if (fabs(captured_params.at("sqrt2dpi").f - sqrt(2.f / M_PI)) > 0.0001f) | |||
| return false; | |||
| if ((captured_params.at("1").type == 2 && captured_params.at("1").i != 1) || (captured_params.at("1").type == 3 && captured_params.at("1").f != 1.f)) | |||
| return false; | |||
| if ((captured_params.at("3").type == 2 && captured_params.at("3").i != 3) || (captured_params.at("3").type == 3 && captured_params.at("3").f != 3.f)) | |||
| return false; | |||
| if ((captured_params.at("1_1").type == 2 && captured_params.at("1_1").i != 1) || (captured_params.at("1_1").type == 3 && captured_params.at("1_1").f != 1.f)) | |||
| return false; | |||
| return true; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "F.gelu"; | |||
| } | |||
| void write(Operator* /*op*/, const std::map<std::string, Parameter>& /*captured_params*/) const | |||
| { | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_gelu_3, 9) | |||
| } // namespace pnnx | |||
| @@ -66,7 +66,7 @@ pnnx.Output output 1 0 out | |||
| return "F.local_response_norm"; | |||
| } | |||
| bool match_captured_params(const std::map<std::string, Parameter>& captured_params) const | |||
| bool match(const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| if (captured_params.at("padzero").type == 2) | |||
| return captured_params.at("padzero").i == 0; | |||
| @@ -168,7 +168,7 @@ pnnx.Output output 1 0 out | |||
| return "F.local_response_norm"; | |||
| } | |||
| bool match_captured_params(const std::map<std::string, Parameter>& captured_params) const | |||
| bool match(const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| if (captured_params.at("padzero").type == 2) | |||
| return captured_params.at("padzero").i == 0; | |||
| @@ -274,7 +274,7 @@ pnnx.Output output 1 0 out | |||
| return "F.local_response_norm"; | |||
| } | |||
| bool match_captured_params(const std::map<std::string, Parameter>& captured_params) const | |||
| bool match(const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| if (captured_params.at("padzero").type == 2) | |||
| return captured_params.at("padzero").i == 0; | |||
| @@ -347,7 +347,7 @@ pnnx.Output output 1 0 out | |||
| return "F.local_response_norm"; | |||
| } | |||
| bool match_captured_params(const std::map<std::string, Parameter>& captured_params) const | |||
| bool match(const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| if (captured_params.at("padzero").type == 2) | |||
| return captured_params.at("padzero").i == 0; | |||
| @@ -450,7 +450,7 @@ pnnx.Output output 1 0 out | |||
| return "F.local_response_norm"; | |||
| } | |||
| bool match_captured_params(const std::map<std::string, Parameter>& captured_params) const | |||
| bool match(const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| if (captured_params.at("padzero").type == 2) | |||
| return captured_params.at("padzero").i == 0; | |||
| @@ -557,7 +557,7 @@ pnnx.Output output 1 0 out | |||
| return "F.local_response_norm"; | |||
| } | |||
| bool match_captured_params(const std::map<std::string, Parameter>& captured_params) const | |||
| bool match(const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| if (captured_params.at("padzero").type == 2) | |||
| return captured_params.at("padzero").i == 0; | |||
| @@ -15,6 +15,13 @@ | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import math | |||
| def gelu_forward_0(x): | |||
| return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | |||
| def gelu_forward_1(x): | |||
| return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| @@ -23,8 +30,8 @@ class Model(nn.Module): | |||
| def forward(self, x, y, z, w): | |||
| x = F.gelu(x) | |||
| y = F.gelu(y) | |||
| z = F.gelu(z) | |||
| w = F.gelu(w) | |||
| z = gelu_forward_0(z) | |||
| w = gelu_forward_1(w) | |||
| return x, y, z, w | |||
| def test(): | |||
| @@ -37,7 +44,7 @@ def test(): | |||
| z = torch.rand(1, 3, 12, 16) | |||
| w = torch.rand(1, 5, 7, 9, 11) | |||
| a0, a1, a2, a3 = net(x, y, z, w) | |||
| a = net(x, y, z, w) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z, w)) | |||
| @@ -49,9 +56,12 @@ def test(): | |||
| # pnnx inference | |||
| import test_F_gelu_pnnx | |||
| b0, b1, b2, b3 = test_F_gelu_pnnx.test_inference() | |||
| b = test_F_gelu_pnnx.test_inference() | |||
| return torch.equal(a0, b0) and torch.equal(a1, b1) and torch.equal(a2, b2) and torch.equal(a3, b3) | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.allclose(a0, b0, 1e-4, 1e-4): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||