From 95dc6ea8ee0db02e1aedf4c1cbb351c4bde041c2 Mon Sep 17 00:00:00 2001 From: nihuini Date: Fri, 1 Aug 2025 14:54:29 +0800 Subject: [PATCH] u --- tools/pnnx/src/pass_level1/nn_Conv1d.cpp | 5 ++--- tools/pnnx/src/pass_level1/nn_Conv2d.cpp | 5 ++--- tools/pnnx/src/pass_level1/nn_Conv3d.cpp | 5 ++--- .../src/pass_level1/nn_ConvTranspose1d.cpp | 18 +++++++++++++++++- .../src/pass_level1/nn_ConvTranspose2d.cpp | 18 +++++++++++++++++- .../src/pass_level1/nn_ConvTranspose3d.cpp | 18 +++++++++++++++++- tools/pnnx/src/utils.cpp | 8 ++++---- tools/pnnx/src/utils.h | 2 +- tools/pnnx/tests/test_nn_ConvTranspose1d.py | 4 +++- tools/pnnx/tests/test_nn_ConvTranspose2d.py | 4 +++- tools/pnnx/tests/test_nn_ConvTranspose3d.py | 4 +++- 11 files changed, 71 insertions(+), 20 deletions(-) diff --git a/tools/pnnx/src/pass_level1/nn_Conv1d.cpp b/tools/pnnx/src/pass_level1/nn_Conv1d.cpp index 9acfc9128..8602c5220 100644 --- a/tools/pnnx/src/pass_level1/nn_Conv1d.cpp +++ b/tools/pnnx/src/pass_level1/nn_Conv1d.cpp @@ -122,9 +122,8 @@ public: std::vector weight_data = op->attrs["weight"].get_float32_data(); std::vector weight_g_data = weight_g.get_float32_data(); int outch = op->params.at("out_channels").i; - int inch = op->params.at("in_channels").i; - int maxk = op->params.at("kernel_size").ai[0]; - apply_weight_norm(weight_data, weight_g_data, outch, inch, maxk); + int inch = op->params.at("in_channels").i * op->params.at("kernel_size").ai[0]; + apply_weight_norm(weight_data, weight_g_data, outch, inch); op->attrs["weight"].set_float32_data(weight_data); // drop the additional weight input diff --git a/tools/pnnx/src/pass_level1/nn_Conv2d.cpp b/tools/pnnx/src/pass_level1/nn_Conv2d.cpp index c9f031a5e..d10af5412 100644 --- a/tools/pnnx/src/pass_level1/nn_Conv2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_Conv2d.cpp @@ -122,9 +122,8 @@ public: std::vector weight_data = op->attrs["weight"].get_float32_data(); std::vector weight_g_data = weight_g.get_float32_data(); int outch = op->params.at("out_channels").i; - int inch = op->params.at("in_channels").i; - int maxk = op->params.at("kernel_size").ai[0] * op->params.at("kernel_size").ai[1]; - apply_weight_norm(weight_data, weight_g_data, outch, inch, maxk); + int inch = op->params.at("in_channels").i * op->params.at("kernel_size").ai[0] * op->params.at("kernel_size").ai[1]; + apply_weight_norm(weight_data, weight_g_data, outch, inch); op->attrs["weight"].set_float32_data(weight_data); // drop the additional weight input diff --git a/tools/pnnx/src/pass_level1/nn_Conv3d.cpp b/tools/pnnx/src/pass_level1/nn_Conv3d.cpp index 895f59b21..302994ea5 100644 --- a/tools/pnnx/src/pass_level1/nn_Conv3d.cpp +++ b/tools/pnnx/src/pass_level1/nn_Conv3d.cpp @@ -122,9 +122,8 @@ public: std::vector weight_data = op->attrs["weight"].get_float32_data(); std::vector weight_g_data = weight_g.get_float32_data(); int outch = op->params.at("out_channels").i; - int inch = op->params.at("in_channels").i; - int maxk = op->params.at("kernel_size").ai[0] * op->params.at("kernel_size").ai[1] * op->params.at("kernel_size").ai[2]; - apply_weight_norm(weight_data, weight_g_data, outch, inch, maxk); + int inch = op->params.at("in_channels").i * op->params.at("kernel_size").ai[0] * op->params.at("kernel_size").ai[1] * op->params.at("kernel_size").ai[2]; + apply_weight_norm(weight_data, weight_g_data, outch, inch); op->attrs["weight"].set_float32_data(weight_data); // drop the additional weight input diff --git a/tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp b/tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp index c124fa9cd..783ff269f 100644 --- a/tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ConvTranspose1d.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause #include "fuse_module_pass.h" +#include "utils.h" namespace pnnx { @@ -22,7 +23,7 @@ public: { const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution"); - const TorchTensorProxy& weight = mod.attr("weight"); + const TorchTensorProxy& weight = mod.hasattr("weight") ? mod.attr("weight") : mod.attr("weight_v"); op->params["groups"] = convolution->namedInput("groups"); op->params["in_channels"] = weight.size(0); @@ -35,6 +36,21 @@ public: op->params["bias"] = mod.hasattr("bias"); op->attrs["weight"] = weight; + if (!mod.hasattr("weight")) + { + // weight norm + Attribute weight_g = mod.attr("weight_g"); + std::vector weight_data = op->attrs["weight"].get_float32_data(); + std::vector weight_g_data = weight_g.get_float32_data(); + int inch = op->params.at("in_channels").i; + int outch = op->params.at("out_channels").i * op->params.at("kernel_size").ai[0]; + apply_weight_norm(weight_data, weight_g_data, inch, outch); + op->attrs["weight"].set_float32_data(weight_data); + + // drop the additional weight input + op->inputs[1]->remove_consumer(op); + op->inputs.resize(1); + } if (mod.hasattr("bias")) { op->attrs["bias"] = mod.attr("bias"); diff --git a/tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp b/tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp index 50104b8f8..fa6d02b7b 100644 --- a/tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ConvTranspose2d.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause #include "fuse_module_pass.h" +#include "utils.h" namespace pnnx { @@ -22,7 +23,7 @@ public: { const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution"); - const TorchTensorProxy& weight = mod.attr("weight"); + const TorchTensorProxy& weight = mod.hasattr("weight") ? mod.attr("weight") : mod.attr("weight_v"); op->params["groups"] = convolution->namedInput("groups"); op->params["in_channels"] = weight.size(0); @@ -35,6 +36,21 @@ public: op->params["bias"] = mod.hasattr("bias"); op->attrs["weight"] = weight; + if (!mod.hasattr("weight")) + { + // weight norm + Attribute weight_g = mod.attr("weight_g"); + std::vector weight_data = op->attrs["weight"].get_float32_data(); + std::vector weight_g_data = weight_g.get_float32_data(); + int inch = op->params.at("in_channels").i; + int outch = op->params.at("out_channels").i * op->params.at("kernel_size").ai[0] * op->params.at("kernel_size").ai[1]; + apply_weight_norm(weight_data, weight_g_data, inch, outch); + op->attrs["weight"].set_float32_data(weight_data); + + // drop the additional weight input + op->inputs[1]->remove_consumer(op); + op->inputs.resize(1); + } if (mod.hasattr("bias")) { op->attrs["bias"] = mod.attr("bias"); diff --git a/tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp b/tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp index e4382a626..e41e43de6 100644 --- a/tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp +++ b/tools/pnnx/src/pass_level1/nn_ConvTranspose3d.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: BSD-3-Clause #include "fuse_module_pass.h" +#include "utils.h" namespace pnnx { @@ -22,7 +23,7 @@ public: { const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution"); - const TorchTensorProxy& weight = mod.attr("weight"); + const TorchTensorProxy& weight = mod.hasattr("weight") ? mod.attr("weight") : mod.attr("weight_v"); op->params["groups"] = convolution->namedInput("groups"); op->params["in_channels"] = weight.size(0); @@ -35,6 +36,21 @@ public: op->params["bias"] = mod.hasattr("bias"); op->attrs["weight"] = weight; + if (!mod.hasattr("weight")) + { + // weight norm + Attribute weight_g = mod.attr("weight_g"); + std::vector weight_data = op->attrs["weight"].get_float32_data(); + std::vector weight_g_data = weight_g.get_float32_data(); + int inch = op->params.at("in_channels").i; + int outch = op->params.at("out_channels").i * op->params.at("kernel_size").ai[0] * op->params.at("kernel_size").ai[1] * op->params.at("kernel_size").ai[2]; + apply_weight_norm(weight_data, weight_g_data, inch, outch); + op->attrs["weight"].set_float32_data(weight_data); + + // drop the additional weight input + op->inputs[1]->remove_consumer(op); + op->inputs.resize(1); + } if (mod.hasattr("bias")) { op->attrs["bias"] = mod.attr("bias"); diff --git a/tools/pnnx/src/utils.cpp b/tools/pnnx/src/utils.cpp index c99364025..6499e5160 100644 --- a/tools/pnnx/src/utils.cpp +++ b/tools/pnnx/src/utils.cpp @@ -112,23 +112,23 @@ float float16_to_float32(unsigned short value) return tmp.f; } -void apply_weight_norm(std::vector& weight, const std::vector& weight_g, int outch, int inch, int maxk) +void apply_weight_norm(std::vector& weight, const std::vector& weight_g, int outch, int inch) { const float eps = 1e-12f; for (int i = 0; i < outch; i++) { - float* pw = weight.data() + i * inch * maxk; + float* pw = weight.data() + i * inch; float norm = 0.f; - for (int j = 0; j < inch * maxk; j++) + for (int j = 0; j < inch; j++) { float w = pw[j]; norm += w * w; } norm = sqrt(norm) + eps; - for (int j = 0; j < inch * maxk; j++) + for (int j = 0; j < inch; j++) { pw[j] = weight_g[i] * pw[j] / norm; } diff --git a/tools/pnnx/src/utils.h b/tools/pnnx/src/utils.h index 31c5e4605..22f7a2208 100644 --- a/tools/pnnx/src/utils.h +++ b/tools/pnnx/src/utils.h @@ -12,7 +12,7 @@ unsigned short float32_to_float16(float value); float float16_to_float32(unsigned short value); -void apply_weight_norm(std::vector& weight, const std::vector& weight_g, int outch, int inch, int maxk); +void apply_weight_norm(std::vector& weight, const std::vector& weight_g, int outch, int inch); } // namespace pnnx diff --git a/tools/pnnx/tests/test_nn_ConvTranspose1d.py b/tools/pnnx/tests/test_nn_ConvTranspose1d.py index 413207fbb..963b0c384 100644 --- a/tools/pnnx/tests/test_nn_ConvTranspose1d.py +++ b/tools/pnnx/tests/test_nn_ConvTranspose1d.py @@ -18,6 +18,8 @@ class Model(nn.Module): self.deconv_6 = nn.ConvTranspose1d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, output_padding=0, dilation=1, groups=1, bias=False) self.deconv_7 = nn.ConvTranspose1d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(6), output_padding=(1), dilation=2, groups=1, bias=True) + self.deconv_7 = torch.nn.utils.weight_norm(self.deconv_7) + self.downsample = nn.Conv1d(24, 16, 3, stride=2, padding=1) self.upsample = nn.ConvTranspose1d(16, 24, 3, stride=2, padding=1) @@ -57,7 +59,7 @@ def test(): import test_nn_ConvTranspose1d_pnnx b = test_nn_ConvTranspose1d_pnnx.test_inference() - return torch.equal(a, b) + return torch.allclose(a, b, 1e-4, 1e-4) if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/test_nn_ConvTranspose2d.py b/tools/pnnx/tests/test_nn_ConvTranspose2d.py index 6e1e544e4..21d41615d 100644 --- a/tools/pnnx/tests/test_nn_ConvTranspose2d.py +++ b/tools/pnnx/tests/test_nn_ConvTranspose2d.py @@ -18,6 +18,8 @@ class Model(nn.Module): self.deconv_6 = nn.ConvTranspose2d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, output_padding=0, dilation=1, groups=1, bias=False) self.deconv_7 = nn.ConvTranspose2d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), output_padding=(1,0), dilation=2, groups=1, bias=True) + self.deconv_7 = torch.nn.utils.weight_norm(self.deconv_7) + self.downsample = nn.Conv2d(24, 16, 3, stride=2, padding=1) self.upsample = nn.ConvTranspose2d(16, 24, 3, stride=2, padding=1) @@ -57,7 +59,7 @@ def test(): import test_nn_ConvTranspose2d_pnnx b = test_nn_ConvTranspose2d_pnnx.test_inference() - return torch.equal(a, b) + return torch.allclose(a, b, 1e-4, 1e-4) if __name__ == "__main__": if test(): diff --git a/tools/pnnx/tests/test_nn_ConvTranspose3d.py b/tools/pnnx/tests/test_nn_ConvTranspose3d.py index c555ab451..1699e7288 100644 --- a/tools/pnnx/tests/test_nn_ConvTranspose3d.py +++ b/tools/pnnx/tests/test_nn_ConvTranspose3d.py @@ -18,6 +18,8 @@ class Model(nn.Module): self.deconv_6 = nn.ConvTranspose3d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, output_padding=0, dilation=1, groups=1, bias=False) self.deconv_7 = nn.ConvTranspose3d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6,7), output_padding=(1,0,1), dilation=2, groups=1, bias=True) + self.deconv_7 = torch.nn.utils.weight_norm(self.deconv_7) + self.downsample = nn.Conv3d(24, 16, 3, stride=2, padding=1) self.upsample = nn.ConvTranspose3d(16, 24, 3, stride=2, padding=1) @@ -57,7 +59,7 @@ def test(): import test_nn_ConvTranspose3d_pnnx b = test_nn_ConvTranspose3d_pnnx.test_inference() - return torch.equal(a, b) + return torch.allclose(a, b, 1e-4, 1e-4) if __name__ == "__main__": if test():