| @@ -104,6 +104,11 @@ int Gemm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl | |||
| // auto broadcast from h to w is the ncnn-style convention | |||
| broadcast_type_C = 1; | |||
| } | |||
| if (C.dims == 1 && C.w == N) | |||
| { | |||
| // N | |||
| broadcast_type_C = 4; | |||
| } | |||
| if (C.dims == 2 && C.w == 1 && C.h == M) | |||
| { | |||
| // Mx1 | |||
| @@ -146,6 +146,19 @@ static int test_gemm_5() | |||
| || test_gemm_bias(16, 24, 15, RandomMat(1, 14), 1.7f, 1.3f, 1, 1); | |||
| } | |||
| static int test_gemm_6() | |||
| { | |||
| return 0 | |||
| || test_gemm_bias(13, 14, 15, RandomMat(14), 0.1f, 0.4f, 0, 0) | |||
| || test_gemm_bias(13, 14, 15, RandomMat(14), 0.4f, -1.f, 1, 0) | |||
| || test_gemm_bias(13, 14, 15, RandomMat(14), -0.3f, -0.21f, 0, 1) | |||
| || test_gemm_bias(13, 14, 15, RandomMat(14), 1.7f, 1.3f, 1, 1) | |||
| || test_gemm_bias(16, 24, 15, RandomMat(14), 0.1f, 0.4f, 0, 0) | |||
| || test_gemm_bias(16, 24, 15, RandomMat(14), 0.4f, -1.f, 1, 0) | |||
| || test_gemm_bias(16, 24, 15, RandomMat(14), -0.3f, -0.21f, 0, 1) | |||
| || test_gemm_bias(16, 24, 15, RandomMat(14), 1.7f, 1.3f, 1, 1); | |||
| } | |||
| int main() | |||
| { | |||
| SRAND(7767517); | |||
| @@ -156,5 +169,6 @@ int main() | |||
| || test_gemm_2() | |||
| || test_gemm_3() | |||
| || test_gemm_4() | |||
| || test_gemm_5(); | |||
| || test_gemm_5() | |||
| || test_gemm_6(); | |||
| } | |||
| @@ -175,6 +175,7 @@ set(pnnx_pass_level2_SRCS | |||
| pass_level2/Tensor_select.cpp | |||
| pass_level2/Tensor_slice.cpp | |||
| pass_level2/Tensor_view.cpp | |||
| pass_level2/torch_addmm.cpp | |||
| pass_level2/torch_amax.cpp | |||
| pass_level2/torch_amin.cpp | |||
| pass_level2/torch_arange.cpp | |||
| @@ -402,6 +403,7 @@ set(pnnx_pass_ncnn_SRCS | |||
| pass_ncnn/Tensor_repeat.cpp | |||
| pass_ncnn/Tensor_slice.cpp | |||
| pass_ncnn/Tensor_view.cpp | |||
| pass_ncnn/torch_addmm.cpp | |||
| pass_ncnn/torch_amax.cpp | |||
| pass_ncnn/torch_amin.cpp | |||
| pass_ncnn/torch_clamp.cpp | |||
| @@ -0,0 +1,44 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level2.h" | |||
| namespace pnnx { | |||
| class torch_addmm : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 7 6 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 mat1 | |||
| pnnx.Input input_2 0 1 mat2 | |||
| pnnx.Input input_3 0 1 beta | |||
| pnnx.Input input_4 0 1 alpha | |||
| aten::addmm op_0 5 1 input mat1 mat2 beta alpha out | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "torch.addmm"; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_addmm, 20) | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,196 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_ncnn.h" | |||
| namespace pnnx { | |||
| namespace ncnn { | |||
| class torch_addmm : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input_0 0 1 mat1 | |||
| pnnx.Attribute op_bias 0 1 bias @qwq | |||
| pnnx.Attribute op_weight 0 1 weight @qwq | |||
| torch.addmm op_0 3 1 bias mat1 weight out alpha=%alpha beta=%beta | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "InnerProduct"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "addmm"; | |||
| } | |||
| bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| float alpha = 1.f; | |||
| float beta = 1.f; | |||
| if (captured_params.at("alpha").type == 2) | |||
| { | |||
| alpha = captured_params.at("alpha").i; | |||
| } | |||
| if (captured_params.at("alpha").type == 3) | |||
| { | |||
| alpha = captured_params.at("alpha").f; | |||
| } | |||
| if (captured_params.at("beta").type == 2) | |||
| { | |||
| beta = captured_params.at("beta").i; | |||
| } | |||
| if (captured_params.at("beta").type == 3) | |||
| { | |||
| beta = captured_params.at("beta").f; | |||
| } | |||
| if (alpha != 1.f || beta != 1.f) | |||
| return false; | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| if (weight.shape.size() != 2 || bias.shape.size() != 1) | |||
| return false; | |||
| if (weight.shape[1] != bias.shape[0]) | |||
| return false; | |||
| return true; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const | |||
| { | |||
| Attribute weight; | |||
| Attribute bias; | |||
| for (const auto& x : captured_attrs) | |||
| { | |||
| if (x.first.substr(0, 10) == "op_weight.") | |||
| weight = x.second; | |||
| if (x.first.substr(0, 8) == "op_bias.") | |||
| bias = x.second; | |||
| } | |||
| // transpose weight inch-outch to outch-inch | |||
| const int inch = weight.shape[0]; | |||
| const int outch = weight.shape[1]; | |||
| std::vector<float> new_weight; | |||
| { | |||
| const float* w = (const float*)weight.data.data(); | |||
| new_weight.resize(outch * inch); | |||
| float* w2 = (float*)new_weight.data(); | |||
| // reorder weight from inch-outch to outch-inch | |||
| for (int i = 0; i < outch; i++) | |||
| { | |||
| for (int j = 0; j < inch; j++) | |||
| { | |||
| w2[i * inch + j] = w[j * outch + i]; | |||
| } | |||
| } | |||
| } | |||
| op->params["0"] = outch; | |||
| op->params["1"] = 1; | |||
| op->params["2"] = (int)(weight.data.size() / sizeof(float)); | |||
| op->attrs["0"] = Attribute(); | |||
| op->attrs["0"].data = {0, 0, 0, 0}; | |||
| op->attrs["1"] = Attribute({outch, inch}, new_weight); | |||
| op->attrs["2"] = bias; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_addmm, 20) | |||
| class torch_addmm_1 : public GraphRewriterPass | |||
| { | |||
| public: | |||
| const char* match_pattern_graph() const | |||
| { | |||
| return R"PNNXIR(7767517 | |||
| 5 4 | |||
| pnnx.Input input_0 0 1 input | |||
| pnnx.Input input_1 0 1 mat1 | |||
| pnnx.Input input_2 0 1 mat2 | |||
| torch.addmm op_0 3 1 input mat1 mat2 out alpha=%alpha beta=%beta | |||
| pnnx.Output output 1 0 out | |||
| )PNNXIR"; | |||
| } | |||
| const char* type_str() const | |||
| { | |||
| return "Gemm"; | |||
| } | |||
| const char* name_str() const | |||
| { | |||
| return "addmm"; | |||
| } | |||
| void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | |||
| { | |||
| std::swap(op->inputs[0], op->inputs[1]); | |||
| std::swap(op->inputs[1], op->inputs[2]); | |||
| float alpha = 1.f; | |||
| float beta = 1.f; | |||
| if (captured_params.at("alpha").type == 2) | |||
| { | |||
| alpha = captured_params.at("alpha").i; | |||
| } | |||
| if (captured_params.at("alpha").type == 3) | |||
| { | |||
| alpha = captured_params.at("alpha").f; | |||
| } | |||
| if (captured_params.at("beta").type == 2) | |||
| { | |||
| beta = captured_params.at("beta").i; | |||
| } | |||
| if (captured_params.at("beta").type == 3) | |||
| { | |||
| beta = captured_params.at("beta").f; | |||
| } | |||
| op->params["0"] = alpha; | |||
| op->params["1"] = beta / alpha; | |||
| } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_addmm_1, 22) | |||
| } // namespace ncnn | |||
| } // namespace pnnx | |||
| @@ -162,6 +162,7 @@ pnnx_add_test(Tensor_select) | |||
| pnnx_add_test(Tensor_slice) | |||
| pnnx_add_test(Tensor_view) | |||
| pnnx_add_test(torch_addmm) | |||
| pnnx_add_test(torch_amax) | |||
| pnnx_add_test(torch_amin) | |||
| pnnx_add_test(torch_argmax) | |||
| @@ -125,6 +125,7 @@ pnnx_ncnn_add_test(Tensor_reshape) | |||
| pnnx_ncnn_add_test(Tensor_slice) | |||
| pnnx_ncnn_add_test(Tensor_view) | |||
| pnnx_ncnn_add_test(torch_addmm) | |||
| pnnx_ncnn_add_test(torch_amax) | |||
| pnnx_ncnn_add_test(torch_amin) | |||
| pnnx_ncnn_add_test(torch_cat) | |||
| @@ -0,0 +1,72 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class Model(nn.Module):
    """Module with three torch.addmm call sites.

    One call uses the default scaling, one passes explicit beta/alpha, and one
    uses learned parameters as the additive term and the right-hand matrix.
    """

    def __init__(self):
        super().__init__()

        # created in this order so the random draws stay reproducible
        self.c0 = nn.Parameter(torch.rand(12))
        self.c2 = nn.Parameter(torch.rand(48, 12))

    def forward(self, a0, a1, a2, b0, b1, b2, c1):
        """Return the results of the three addmm calls as a tuple."""
        plain = torch.addmm(a0, a1, a2)
        scaled = torch.addmm(b0, b1, b2, beta=1.4, alpha=0.7)
        learned = torch.addmm(self.c0, c1, self.c2, beta=1, alpha=1)
        return plain, scaled, learned
def test():
    """Trace Model, convert it with pnnx, and compare ncnn output to PyTorch.

    Returns True when every output matches within tolerance, False otherwise
    (after printing the shapes and values of the first mismatching pair).
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    a0 = torch.rand(23)
    a1 = torch.rand(13, 16)
    a2 = torch.rand(16, 23)
    b0 = torch.rand(7, 33)
    b1 = torch.rand(7, 26)
    b2 = torch.rand(26, 33)
    c1 = torch.rand(16, 48)
    inputs = (a0, a1, a2, b0, b1, b2, c1)

    expected = net(*inputs)

    # export torchscript
    traced = torch.jit.trace(net, inputs)
    traced.save("test_torch_addmm.pt")

    # torchscript to pnnx
    import os
    os.system("../../src/pnnx test_torch_addmm.pt inputshape=[23],[13,16],[16,23],[7,33],[7,26],[26,33],[16,48]")

    # ncnn inference; the module is imported only after the conversion step above has run
    import test_torch_addmm_ncnn
    actual = test_torch_addmm_ncnn.test_inference()

    for ref, out in zip(expected, actual):
        if torch.allclose(ref, out, 1e-4, 1e-4):
            continue
        print(ref.shape)
        print(out.shape)
        print(ref)
        print(out)
        return False
    return True
if __name__ == "__main__":
    # exit status 0 on success, 1 on failure
    exit(0 if test() else 1)
| @@ -0,0 +1,68 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
class Model(nn.Module):
    """Module exercising torch.addmm three ways.

    The calls cover default scaling, explicit beta/alpha values, and learned
    parameters as the additive term and the second matrix.
    """

    def __init__(self):
        super().__init__()

        # c0 is drawn before c2; keep the order so seeded runs are unchanged
        self.c0 = nn.Parameter(torch.rand(12))
        self.c2 = nn.Parameter(torch.rand(48, 12))

    def forward(self, a0, a1, a2, b0, b1, b2, c1):
        """Return a tuple with the three addmm results."""
        out_default = torch.addmm(a0, a1, a2)
        out_scaled = torch.addmm(b0, b1, b2, beta=1.4, alpha=0.7)
        out_param = torch.addmm(self.c0, c1, self.c2, beta=1, alpha=1)
        return out_default, out_scaled, out_param
def test():
    """Trace Model, convert it with pnnx, and compare pnnx output to PyTorch.

    Returns True only when every output is exactly equal to its reference.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    a0 = torch.rand(13, 1)
    a1 = torch.rand(13, 16)
    a2 = torch.rand(16, 23)
    b0 = torch.rand(7, 33)
    b1 = torch.rand(7, 26)
    b2 = torch.rand(26, 33)
    c1 = torch.rand(16, 48)
    inputs = (a0, a1, a2, b0, b1, b2, c1)

    expected = net(*inputs)

    # export torchscript
    traced = torch.jit.trace(net, inputs)
    traced.save("test_torch_addmm.pt")

    # torchscript to pnnx
    import os
    os.system("../src/pnnx test_torch_addmm.pt inputshape=[13,1],[13,16],[16,23],[7,33],[7,26],[26,33],[16,48]")

    # pnnx inference; the module is imported only after the conversion step above has run
    import test_torch_addmm_pnnx
    actual = test_torch_addmm_pnnx.test_inference()

    return all(torch.equal(ref, out) for ref, out in zip(expected, actual))
if __name__ == "__main__":
    # exit status 0 on success, 1 on failure
    exit(0 if test() else 1)