Browse Source

pnnx convert torch.addmm (#3634)

tags/20220420
nihui GitHub 4 years ago
parent
commit
d42e048b56
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 404 additions and 1 deletions
  1. +5
    -0
      src/layer/gemm.cpp
  2. +15
    -1
      tests/test_gemm.cpp
  3. +2
    -0
      tools/pnnx/src/CMakeLists.txt
  4. +44
    -0
      tools/pnnx/src/pass_level2/torch_addmm.cpp
  5. +196
    -0
      tools/pnnx/src/pass_ncnn/torch_addmm.cpp
  6. +1
    -0
      tools/pnnx/tests/CMakeLists.txt
  7. +1
    -0
      tools/pnnx/tests/ncnn/CMakeLists.txt
  8. +72
    -0
      tools/pnnx/tests/ncnn/test_torch_addmm.py
  9. +68
    -0
      tools/pnnx/tests/test_torch_addmm.py

+ 5
- 0
src/layer/gemm.cpp View File

@@ -104,6 +104,11 @@ int Gemm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
// auto broadcast from h to w is the ncnn-style convention
broadcast_type_C = 1;
}
if (C.dims == 1 && C.w == N)
{
// N
broadcast_type_C = 4;
}
if (C.dims == 2 && C.w == 1 && C.h == M)
{
// Mx1


+ 15
- 1
tests/test_gemm.cpp View File

@@ -146,6 +146,19 @@ static int test_gemm_5()
|| test_gemm_bias(16, 24, 15, RandomMat(1, 14), 1.7f, 1.3f, 1, 1);
}

static int test_gemm_6()
{
    // Table of argument sets for test_gemm_bias. The fourth argument of every
    // call is a freshly generated RandomMat(14), so it is not stored here.
    struct GemmCase
    {
        int d0;
        int d1;
        int d2;
        float s0;
        float s1;
        int f0;
        int f1;
    };

    static const GemmCase cases[] = {
        {13, 14, 15, 0.1f, 0.4f, 0, 0},
        {13, 14, 15, 0.4f, -1.f, 1, 0},
        {13, 14, 15, -0.3f, -0.21f, 0, 1},
        {13, 14, 15, 1.7f, 1.3f, 1, 1},
        {16, 24, 15, 0.1f, 0.4f, 0, 0},
        {16, 24, 15, 0.4f, -1.f, 1, 0},
        {16, 24, 15, -0.3f, -0.21f, 0, 1},
        {16, 24, 15, 1.7f, 1.3f, 1, 1}
    };

    const int case_count = (int)(sizeof(cases) / sizeof(cases[0]));

    // Stop at the first failing case; later cases (and their random inputs)
    // are then never generated.
    for (int i = 0; i < case_count; i++)
    {
        const GemmCase& c = cases[i];
        if (test_gemm_bias(c.d0, c.d1, c.d2, RandomMat(14), c.s0, c.s1, c.f0, c.f1))
            return 1;
    }

    return 0;
}

int main()
{
SRAND(7767517);
@@ -156,5 +169,6 @@ int main()
|| test_gemm_2()
|| test_gemm_3()
|| test_gemm_4()
|| test_gemm_5();
|| test_gemm_5()
|| test_gemm_6();
}

+ 2
- 0
tools/pnnx/src/CMakeLists.txt View File

@@ -175,6 +175,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/Tensor_select.cpp
pass_level2/Tensor_slice.cpp
pass_level2/Tensor_view.cpp
pass_level2/torch_addmm.cpp
pass_level2/torch_amax.cpp
pass_level2/torch_amin.cpp
pass_level2/torch_arange.cpp
@@ -402,6 +403,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/Tensor_repeat.cpp
pass_ncnn/Tensor_slice.cpp
pass_ncnn/Tensor_view.cpp
pass_ncnn/torch_addmm.cpp
pass_ncnn/torch_amax.cpp
pass_ncnn/torch_amin.cpp
pass_ncnn/torch_clamp.cpp


+ 44
- 0
tools/pnnx/src/pass_level2/torch_addmm.cpp View File

@@ -0,0 +1,44 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

// Graph rewriter pass that matches a single aten::addmm node and gives the
// matched subgraph the operator type "torch.addmm".
class torch_addmm : public GraphRewriterPass
{
public:
// Returns the PNNXIR text of the pattern to match: five pnnx.Input operands
// (input, mat1, mat2, beta, alpha) feeding one aten::addmm whose single
// output goes to pnnx.Output.
// NOTE(review): the "7 6" header line appears to be the operator count and
// the operand count of the pattern - confirm against the PNNXIR parser.
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
7 6
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 mat1
pnnx.Input input_2 0 1 mat2
pnnx.Input input_3 0 1 beta
pnnx.Input input_4 0 1 alpha
aten::addmm op_0 5 1 input mat1 mat2 beta alpha out
pnnx.Output output 1 0 out
)PNNXIR";
}

// Operator type string assigned to the matched pattern.
const char* type_str() const
{
return "torch.addmm";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_addmm, 20)

} // namespace pnnx

+ 196
- 0
tools/pnnx/src/pass_ncnn/torch_addmm.cpp View File

@@ -0,0 +1,196 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class torch_addmm : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input_0 0 1 mat1
pnnx.Attribute op_bias 0 1 bias @qwq
pnnx.Attribute op_weight 0 1 weight @qwq
torch.addmm op_0 3 1 bias mat1 weight out alpha=%alpha beta=%beta
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "InnerProduct";
}

const char* name_str() const
{
return "addmm";
}

bool match(const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
float alpha = 1.f;
float beta = 1.f;

if (captured_params.at("alpha").type == 2)
{
alpha = captured_params.at("alpha").i;
}
if (captured_params.at("alpha").type == 3)
{
alpha = captured_params.at("alpha").f;
}

if (captured_params.at("beta").type == 2)
{
beta = captured_params.at("beta").i;
}
if (captured_params.at("beta").type == 3)
{
beta = captured_params.at("beta").f;
}

if (alpha != 1.f || beta != 1.f)
return false;

Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

if (weight.shape.size() != 2 || bias.shape.size() != 1)
return false;

if (weight.shape[1] != bias.shape[0])
return false;

return true;
}

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/, const std::map<std::string, Attribute>& captured_attrs) const
{
Attribute weight;
Attribute bias;
for (const auto& x : captured_attrs)
{
if (x.first.substr(0, 10) == "op_weight.")
weight = x.second;
if (x.first.substr(0, 8) == "op_bias.")
bias = x.second;
}

// transpose weight inch-outch to outch-inch
const int inch = weight.shape[0];
const int outch = weight.shape[1];
std::vector<float> new_weight;
{
const float* w = (const float*)weight.data.data();

new_weight.resize(outch * inch);
float* w2 = (float*)new_weight.data();

// reorder weight from inch-outch to outch-inch
for (int i = 0; i < outch; i++)
{
for (int j = 0; j < inch; j++)
{
w2[i * inch + j] = w[j * outch + i];
}
}
}

op->params["0"] = outch;
op->params["1"] = 1;
op->params["2"] = (int)(weight.data.size() / sizeof(float));

op->attrs["0"] = Attribute();
op->attrs["0"].data = {0, 0, 0, 0};
op->attrs["1"] = Attribute({outch, inch}, new_weight);
op->attrs["2"] = bias;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_addmm, 20)

class torch_addmm_1 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
5 4
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 mat1
pnnx.Input input_2 0 1 mat2
torch.addmm op_0 3 1 input mat1 mat2 out alpha=%alpha beta=%beta
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Gemm";
}

const char* name_str() const
{
return "addmm";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
std::swap(op->inputs[0], op->inputs[1]);
std::swap(op->inputs[1], op->inputs[2]);

float alpha = 1.f;
float beta = 1.f;

if (captured_params.at("alpha").type == 2)
{
alpha = captured_params.at("alpha").i;
}
if (captured_params.at("alpha").type == 3)
{
alpha = captured_params.at("alpha").f;
}

if (captured_params.at("beta").type == 2)
{
beta = captured_params.at("beta").i;
}
if (captured_params.at("beta").type == 3)
{
beta = captured_params.at("beta").f;
}

op->params["0"] = alpha;
op->params["1"] = beta / alpha;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_addmm_1, 22)

} // namespace ncnn

} // namespace pnnx

+ 1
- 0
tools/pnnx/tests/CMakeLists.txt View File

@@ -162,6 +162,7 @@ pnnx_add_test(Tensor_select)
pnnx_add_test(Tensor_slice)
pnnx_add_test(Tensor_view)

pnnx_add_test(torch_addmm)
pnnx_add_test(torch_amax)
pnnx_add_test(torch_amin)
pnnx_add_test(torch_argmax)


+ 1
- 0
tools/pnnx/tests/ncnn/CMakeLists.txt View File

@@ -125,6 +125,7 @@ pnnx_ncnn_add_test(Tensor_reshape)
pnnx_ncnn_add_test(Tensor_slice)
pnnx_ncnn_add_test(Tensor_view)

pnnx_ncnn_add_test(torch_addmm)
pnnx_ncnn_add_test(torch_amax)
pnnx_ncnn_add_test(torch_amin)
pnnx_ncnn_add_test(torch_cat)


+ 72
- 0
tools/pnnx/tests/ncnn/test_torch_addmm.py View File

@@ -0,0 +1,72 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Module exercising three forms of torch.addmm.

    a: default alpha/beta with runtime operands.
    b: explicit beta=1.4, alpha=0.7 with runtime operands.
    c: alpha=beta=1 with the bias and second matrix held as parameters.
    """

    def __init__(self):
        super(Model, self).__init__()

        # Bias of length 12 and a 48x12 matrix used by the third addmm.
        self.c0 = nn.Parameter(torch.rand(12))
        self.c2 = nn.Parameter(torch.rand(48, 12))

    def forward(self, a0, a1, a2, b0, b1, b2, c1):
        """Return the tuple (a, b, c) of the three addmm results."""
        a = torch.addmm(a0, a1, a2)
        b = torch.addmm(b0, b1, b2, beta=1.4, alpha=0.7)
        c = torch.addmm(self.c0, c1, self.c2, beta=1, alpha=1)
        return a, b, c

def test():
    """Trace Model, convert it with pnnx and compare the ncnn results.

    Returns True when every output of the generated ncnn inference script
    matches the PyTorch output within a tolerance of 1e-4 (relative and
    absolute), False otherwise.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    a0 = torch.rand(23)
    a1 = torch.rand(13, 16)
    a2 = torch.rand(16, 23)
    b0 = torch.rand(7, 33)
    b1 = torch.rand(7, 26)
    b2 = torch.rand(26, 33)
    c1 = torch.rand(16, 48)

    expected = net(a0, a1, a2, b0, b1, b2, c1)

    # export torchscript
    mod = torch.jit.trace(net, (a0, a1, a2, b0, b1, b2, c1))
    mod.save("test_torch_addmm.pt")

    # torchscript to pnnx
    import os
    status = os.system("../../src/pnnx test_torch_addmm.pt inputshape=[23],[13,16],[16,23],[7,33],[7,26],[26,33],[16,48]")
    if status != 0:
        # Without this check a failed conversion could go unnoticed and a
        # stale generated module from an earlier run would be imported below.
        print("pnnx conversion failed with status", status)
        return False

    # ncnn inference
    import test_torch_addmm_ncnn
    actual = test_torch_addmm_ncnn.test_inference()

    for ref, out in zip(expected, actual):
        if not torch.allclose(ref, out, 1e-4, 1e-4):
            print(ref.shape)
            print(out.shape)
            print(ref)
            print(out)
            return False
    return True

# Script entry point: exit status 0 when test() passes, 1 otherwise.
if __name__ == "__main__":
    if test():
        exit(0)
    else:
        exit(1)

+ 68
- 0
tools/pnnx/tests/test_torch_addmm.py View File

@@ -0,0 +1,68 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Module exercising three forms of torch.addmm.

    a: default alpha/beta with runtime operands.
    b: explicit beta=1.4, alpha=0.7 with runtime operands.
    c: alpha=beta=1 with the bias and second matrix held as parameters.
    """

    def __init__(self):
        super(Model, self).__init__()

        # Bias of length 12 and a 48x12 matrix used by the third addmm.
        self.c0 = nn.Parameter(torch.rand(12))
        self.c2 = nn.Parameter(torch.rand(48, 12))

    def forward(self, a0, a1, a2, b0, b1, b2, c1):
        """Return the tuple (a, b, c) of the three addmm results."""
        a = torch.addmm(a0, a1, a2)
        b = torch.addmm(b0, b1, b2, beta=1.4, alpha=0.7)
        c = torch.addmm(self.c0, c1, self.c2, beta=1, alpha=1)
        return a, b, c

def test():
    """Trace Model, convert it with pnnx and compare the pnnx results.

    Returns True when every output of the generated pnnx inference script is
    exactly equal (torch.equal) to the PyTorch output, False otherwise.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    a0 = torch.rand(13, 1)
    a1 = torch.rand(13, 16)
    a2 = torch.rand(16, 23)
    b0 = torch.rand(7, 33)
    b1 = torch.rand(7, 26)
    b2 = torch.rand(26, 33)
    c1 = torch.rand(16, 48)

    expected = net(a0, a1, a2, b0, b1, b2, c1)

    # export torchscript
    mod = torch.jit.trace(net, (a0, a1, a2, b0, b1, b2, c1))
    mod.save("test_torch_addmm.pt")

    # torchscript to pnnx
    import os
    status = os.system("../src/pnnx test_torch_addmm.pt inputshape=[13,1],[13,16],[16,23],[7,33],[7,26],[26,33],[16,48]")
    if status != 0:
        # Without this check a failed conversion could go unnoticed and a
        # stale generated module from an earlier run would be imported below.
        print("pnnx conversion failed with status", status)
        return False

    # pnnx inference
    import test_torch_addmm_pnnx
    actual = test_torch_addmm_pnnx.test_inference()

    for ref, out in zip(expected, actual):
        if not torch.equal(ref, out):
            return False
    return True

# Script entry point: exit status 0 when test() passes, 1 otherwise.
if __name__ == "__main__":
    if test():
        exit(0)
    else:
        exit(1)

Loading…
Cancel
Save