From 6e49fa30dc2cd6b438733520fc9d5b0139c697bd Mon Sep 17 00:00:00 2001 From: nihui Date: Mon, 31 Oct 2022 10:56:42 +0800 Subject: [PATCH] groupnorm 1d/2d/4d (#4312) --- src/layer/groupnorm.cpp | 194 ++++++++++++++++----- tests/test_groupnorm.cpp | 38 +++- tools/pnnx/tests/ncnn/test_F_group_norm.py | 18 +- tools/pnnx/tests/ncnn/test_nn_GroupNorm.py | 29 ++- 4 files changed, 225 insertions(+), 54 deletions(-) diff --git a/src/layer/groupnorm.cpp b/src/layer/groupnorm.cpp index 81847d573..596d39743 100644 --- a/src/layer/groupnorm.cpp +++ b/src/layer/groupnorm.cpp @@ -52,66 +52,180 @@ int GroupNorm::load_model(const ModelBin& mb) int GroupNorm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { - // x = (x - mean) / sqrt(var + eps) * gamma + beta + const int dims = bottom_top_blob.dims; + const int channels_per_group = channels / group; - int w = bottom_top_blob.w; - int h = bottom_top_blob.h; - int size = w * h; - - int channels_per_group = channels / group; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int g = 0; g < group; g++) + if (dims == 1) { - Mat bottom_top_blob_g = bottom_top_blob.channel_range(g * channels_per_group, channels_per_group); - - // mean and var - float sum = 0.f; - for (int q = 0; q < channels_per_group; q++) + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < group; g++) { - const float* ptr = bottom_top_blob_g.channel(q); - for (int i = 0; i < size; i++) + Mat bottom_top_blob_g = bottom_top_blob.range(g * channels_per_group, channels_per_group); + const Mat gamma_data_g = gamma_data.range(g * channels_per_group, channels_per_group); + const Mat beta_data_g = beta_data.range(g * channels_per_group, channels_per_group); + + // mean and var + float sum = 0.f; + for (int q = 0; q < channels_per_group; q++) { - sum += ptr[i]; + sum += bottom_top_blob_g[q]; } - } - float mean = sum / (channels_per_group * size); + float mean = sum / channels_per_group; - float sqsum = 0.f; - for (int q = 0; q < channels_per_group; q++) - { - const float* ptr = bottom_top_blob_g.channel(q); - for (int i = 0; i < size; i++) + float sqsum = 0.f; + for (int q = 0; q < channels_per_group; q++) { - float tmp = ptr[i] - mean; + float tmp = bottom_top_blob_g[q] - mean; sqsum += tmp * tmp; } + float var = sqsum / channels_per_group; + + for (int q = 0; q < channels_per_group; q++) + { + float a; + float b; + if (affine) + { + float gamma = gamma_data_g[q]; + float beta = beta_data_g[q]; + + a = (float)(gamma / sqrt(var + eps)); + b = -mean * a + beta; + } + else + { + a = (float)(1.f / (sqrt(var + eps))); + b = -mean * a; + } + + bottom_top_blob_g[q] = bottom_top_blob_g[q] * a + b; + } } - float var = sqsum / (channels_per_group * size); + } - for (int q = 0; q < channels_per_group; q++) + if (dims == 2) + { + int w = bottom_top_blob.w; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < group; g++) { - float a; - float b; - if (affine) + Mat bottom_top_blob_g = bottom_top_blob.row_range(g * channels_per_group, channels_per_group); + const Mat gamma_data_g = gamma_data.range(g * channels_per_group, channels_per_group); + const Mat beta_data_g = beta_data.range(g * channels_per_group, channels_per_group); + + // mean and var + float sum = 0.f; + for (int q = 0; q < channels_per_group; q++) { - float gamma = gamma_data[g * channels_per_group + q]; - float beta = beta_data[g * channels_per_group + q]; + const float* ptr = bottom_top_blob_g.row(q); + for (int i = 0; i < w; i++) + { + sum += ptr[i]; + } + } + float mean = sum / (channels_per_group * w); - a = static_cast(gamma / sqrt(var + eps)); - b = -mean * a + beta; + float sqsum = 0.f; + for (int q = 0; q < channels_per_group; q++) + { + const float* ptr = bottom_top_blob_g.row(q); + for (int i = 0; i < w; i++) + { + float tmp = ptr[i] - mean; + sqsum += tmp * tmp; + } } - else + float var = sqsum / (channels_per_group * w); + + for (int q = 0; q < channels_per_group; q++) { - a = static_cast(1.f / (sqrt(var + eps))); - b = -mean * a; + float a; + float b; + if (affine) + { + float gamma = gamma_data_g[q]; + float beta = beta_data_g[q]; + + a = (float)(gamma / sqrt(var + eps)); + b = -mean * a + beta; + } + else + { + a = (float)(1.f / (sqrt(var + eps))); + b = -mean * a; + } + + float* ptr = bottom_top_blob_g.row(q); + for (int i = 0; i < w; i++) + { + ptr[i] = ptr[i] * a + b; + } } + } + } - float* ptr = bottom_top_blob_g.channel(q); + if (dims == 3 || dims == 4) + { + int w = bottom_top_blob.w; + int h = bottom_top_blob.h; + int d = bottom_top_blob.d; + int size = w * h * d; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < group; g++) + { + Mat bottom_top_blob_g = bottom_top_blob.channel_range(g * channels_per_group, channels_per_group); + const Mat gamma_data_g = gamma_data.range(g * channels_per_group, channels_per_group); + const Mat beta_data_g = beta_data.range(g * channels_per_group, channels_per_group); + + // mean and var + float sum = 0.f; + for (int q = 0; q < channels_per_group; q++) + { + const float* ptr = bottom_top_blob_g.channel(q); + for (int i = 0; i < size; i++) + { + sum += ptr[i]; + } + } + float mean = sum / (channels_per_group * size); + + float sqsum = 0.f; + for (int q = 0; q < channels_per_group; q++) + { + const float* ptr = bottom_top_blob_g.channel(q); + for (int i = 0; i < size; i++) + { + float tmp = ptr[i] - mean; + sqsum += tmp * tmp; + } + } + float var = sqsum / (channels_per_group * size); - for (int i = 0; i < size; i++) + for (int q = 0; q < channels_per_group; q++) { - ptr[i] = ptr[i] * a + b; + float a; + float b; + if (affine) + { + float gamma = gamma_data_g[q]; + float beta = beta_data_g[q]; + + a = (float)(gamma / sqrt(var + eps)); + b = -mean * a + beta; + } + else + { + a = (float)(1.f / (sqrt(var + eps))); + b = -mean * a; + } + + float* ptr = bottom_top_blob_g.channel(q); + for (int i = 0; i < size; i++) + { + ptr[i] = ptr[i] * a + b; + } } } } diff --git a/tests/test_groupnorm.cpp b/tests/test_groupnorm.cpp index e1e831c06..5f3c5c569 100644 --- a/tests/test_groupnorm.cpp +++ b/tests/test_groupnorm.cpp @@ -38,6 +38,17 @@ static int test_groupnorm(const ncnn::Mat& a, int group, float eps) } static int test_groupnorm_0() +{ + return 0 + || test_groupnorm(RandomMat(3, 6, 4, 2), 1, 0.01f) + || test_groupnorm(RandomMat(2, 3, 3, 8), 2, 0.002f) + || test_groupnorm(RandomMat(3, 4, 5, 6), 3, 0.01f) + || test_groupnorm(RandomMat(4, 5, 6, 12), 4, 0.02f) + || test_groupnorm(RandomMat(5, 6, 7, 24), 2, 0.001f) + || test_groupnorm(RandomMat(2, 8, 9, 24), 3, 0.0001f); +} + +static int test_groupnorm_1() { return 0 || test_groupnorm(RandomMat(6, 4, 2), 1, 0.01f) @@ -48,10 +59,35 @@ static int test_groupnorm_0() || test_groupnorm(RandomMat(8, 9, 24), 3, 0.0001f); } +static int test_groupnorm_2() +{ + return 0 + || test_groupnorm(RandomMat(24, 2), 1, 0.01f) + || test_groupnorm(RandomMat(23, 8), 2, 0.002f) + || test_groupnorm(RandomMat(25, 6), 3, 0.01f) + || test_groupnorm(RandomMat(26, 12), 4, 0.02f) + || test_groupnorm(RandomMat(27, 24), 2, 0.001f) + || test_groupnorm(RandomMat(29, 24), 3, 0.0001f); +} + +static int test_groupnorm_3() +{ + return 0 + || test_groupnorm(RandomMat(12), 1, 0.01f) + || test_groupnorm(RandomMat(18), 2, 0.002f) + || test_groupnorm(RandomMat(36), 3, 0.01f) + || test_groupnorm(RandomMat(212), 4, 0.02f) + || test_groupnorm(RandomMat(124), 2, 0.001f) + || test_groupnorm(RandomMat(324), 3, 0.0001f); +} + int main() { SRAND(7767517); return 0 - || test_groupnorm_0(); + || test_groupnorm_0() + || test_groupnorm_1() + || test_groupnorm_2() + || test_groupnorm_3(); } diff --git a/tools/pnnx/tests/ncnn/test_F_group_norm.py b/tools/pnnx/tests/ncnn/test_F_group_norm.py index 0e4710fbb..6b0347950 100644 --- a/tools/pnnx/tests/ncnn/test_F_group_norm.py +++ b/tools/pnnx/tests/ncnn/test_F_group_norm.py @@ -20,29 +20,37 @@ class Model(nn.Module): def __init__(self): super(Model, self).__init__() + self.w3 = nn.Parameter(torch.rand(16)) + self.b3 = nn.Parameter(torch.rand(16)) + self.w4 = nn.Parameter(torch.rand(12)) + self.b4 = nn.Parameter(torch.rand(12)) self.w5 = nn.Parameter(torch.rand(32)) self.b5 = nn.Parameter(torch.rand(32)) - def forward(self, z): + def forward(self, x, y, z): + x = F.group_norm(x, 4, self.w3, self.b3) + y = F.group_norm(y, 6, self.w4, self.b4) z = F.group_norm(z, 8, self.w5, self.b5, eps=1e-2) - return z + return x, y, z def test(): net = Model() net.eval() torch.manual_seed(0) + x = torch.rand(1, 16) + y = torch.rand(1, 12, 16) z = torch.rand(1, 32, 12, 16) - a = net(z) + a = net(x, y, z) # export torchscript - mod = torch.jit.trace(net, z) + mod = torch.jit.trace(net, (x, y, z)) mod.save("test_F_group_norm.pt") # torchscript to pnnx import os - os.system("../../src/pnnx test_F_group_norm.pt inputshape=[1,32,12,16]") + os.system("../../src/pnnx test_F_group_norm.pt inputshape=[1,16],[1,12,16],[1,32,12,16]") # ncnn inference import test_F_group_norm_ncnn diff --git a/tools/pnnx/tests/ncnn/test_nn_GroupNorm.py b/tools/pnnx/tests/ncnn/test_nn_GroupNorm.py index 71f7d684f..c016e7ae2 100644 --- a/tools/pnnx/tests/ncnn/test_nn_GroupNorm.py +++ b/tools/pnnx/tests/ncnn/test_nn_GroupNorm.py @@ -24,34 +24,47 @@ class Model(nn.Module): self.gn_1 = nn.GroupNorm(num_groups=12, num_channels=12, eps=1e-2, affine=True) self.gn_2 = nn.GroupNorm(num_groups=1, num_channels=12, eps=1e-4, affine=True) - def forward(self, x): + def forward(self, x, y, z): x = self.gn_0(x) x = self.gn_1(x) x = self.gn_2(x) - return x + + y = self.gn_0(y) + y = self.gn_1(y) + y = self.gn_2(y) + + z = self.gn_0(z) + z = self.gn_1(z) + z = self.gn_2(z) + return x, y, z def test(): net = Model() net.eval() torch.manual_seed(0) - x = torch.rand(1, 12, 24, 64) + x = torch.rand(1, 12, 64) + y = torch.rand(1, 12, 24, 64) + z = torch.rand(1, 12, 24, 32, 64) - a0 = net(x) + a = net(x, y, z) # export torchscript - mod = torch.jit.trace(net, x) + mod = torch.jit.trace(net, (x, y, z)) mod.save("test_nn_GroupNorm.pt") # torchscript to pnnx import os - os.system("../../src/pnnx test_nn_GroupNorm.pt inputshape=[1,12,24,64]") + os.system("../../src/pnnx test_nn_GroupNorm.pt inputshape=[1,12,64],[1,12,24,64],[1,12,24,32,64]") # ncnn inference import test_nn_GroupNorm_ncnn - b0 = test_nn_GroupNorm_ncnn.test_inference() + b = test_nn_GroupNorm_ncnn.test_inference() - return torch.allclose(a0, b0, 1e-4, 1e-4) + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True if __name__ == "__main__": if test():