Browse Source

RMSNorm (#5630)

tags/20240820
nihui GitHub 1 year ago
parent
commit
fdf0df3079
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
15 changed files with 669 additions and 7 deletions
  1. +21
    -0
      docs/developer-guide/operators.md
  2. +1
    -0
      src/CMakeLists.txt
  3. +200
    -0
      src/layer/rmsnorm.cpp
  4. +43
    -0
      src/layer/rmsnorm.h
  5. +1
    -0
      tests/CMakeLists.txt
  6. +121
    -0
      tests/test_rmsnorm.cpp
  7. +2
    -0
      tools/pnnx/src/CMakeLists.txt
  8. +1
    -1
      tools/pnnx/src/pass_level1/nn_RMSNorm.cpp
  9. +65
    -0
      tools/pnnx/src/pass_ncnn/F_rms_norm.cpp
  10. +70
    -0
      tools/pnnx/src/pass_ncnn/nn_RMSNorm.cpp
  11. +2
    -0
      tools/pnnx/tests/ncnn/CMakeLists.txt
  12. +3
    -3
      tools/pnnx/tests/ncnn/test_F_layer_norm.py
  13. +68
    -0
      tools/pnnx/tests/ncnn/test_F_rms_norm.py
  14. +3
    -3
      tools/pnnx/tests/ncnn/test_nn_LayerNorm.py
  15. +68
    -0
      tools/pnnx/tests/ncnn/test_nn_RMSNorm.py

+ 21
- 0
docs/developer-guide/operators.md View File

@@ -71,6 +71,7 @@
* [Reorg](#reorg)
* [Requantize](#requantize)
* [Reshape](#reshape)
* [RMSNorm](#rmsnorm)
* [RNN](#rnn)
* [Scale](#scale)
* [SELU](#selu)
@@ -1670,6 +1671,26 @@ Reshape flag:
- -1 = remaining
- -233 = drop this dim(default)

# RMSNorm
```
split x along outmost axis into part x0, x1 ...
root mean square normalize for each part x0, x1 ...
y = x * gamma by elementwise
```

* one_blob_only
* support_inplace

| param id | name | type | default | description |
| --------- | ------------- | ----- | --------- | ----------------- |
| 0 | affine_size | int | 0 | |
| 1 | eps | float | 0.001f | x = x / sqrt(var + eps) |
| 2 | affine | int | 1 | |

| weight | type | shape |
| ------------- | ----- | --------------------- |
| gamma_data | float | [affine_size] |

# RNN
Apply a single-layer RNN to a feature sequence of `T` timesteps. The input blob shape is `[w=input_size, h=T]` and the output blob shape is `[w=num_output, h=T]`.



+ 1
- 0
src/CMakeLists.txt View File

@@ -166,6 +166,7 @@ ncnn_add_layer(Erf)
ncnn_add_layer(Diag)
ncnn_add_layer(CELU)
ncnn_add_layer(Shrink)
ncnn_add_layer(RMSNorm)

if(NCNN_VULKAN)
ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp)


+ 200
- 0
src/layer/rmsnorm.cpp View File

@@ -0,0 +1,200 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "rmsnorm.h"

namespace ncnn {

RMSNorm::RMSNorm()
{
one_blob_only = true;
support_inplace = true;
}

int RMSNorm::load_param(const ParamDict& pd)
{
affine_size = pd.get(0, 0);
eps = pd.get(1, 0.001f);
affine = pd.get(2, 1);

return 0;
}

int RMSNorm::load_model(const ModelBin& mb)
{
if (affine == 0)
return 0;

gamma_data = mb.load(affine_size, 1);
if (gamma_data.empty())
return -100;

return 0;
}

int RMSNorm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
// x = x / sqrt(rms + eps) * gamma

int dims = bottom_top_blob.dims;

if (dims == 1)
{
int w = bottom_top_blob.w;
// assert affine_size == w

float* ptr = bottom_top_blob;

float sqsum = 0.f;
for (int i = 0; i < w; i++)
{
sqsum += ptr[i] * ptr[i];
}
float rms = sqrtf(sqsum / w + eps);

float a = 1.f / rms;

if (affine)
{
for (int i = 0; i < w; i++)
{
ptr[i] = (ptr[i] * a) * gamma_data[i];
}
}
else
{
for (int i = 0; i < w; i++)
{
ptr[i] = ptr[i] * a;
}
}
}

if (dims == 2)
{
int w = bottom_top_blob.w;
int h = bottom_top_blob.h;
// assert affine_size == w

#pragma omp parallel for num_threads(opt.num_threads)
for (int i = 0; i < h; i++)
{
float* ptr = bottom_top_blob.row(i);

float sqsum = 0.f;
for (int j = 0; j < w; j++)
{
sqsum += ptr[j] * ptr[j];
}
float rms = sqrtf(sqsum / w + eps);

float a = 1.f / rms;

if (affine)
{
for (int j = 0; j < w; j++)
{
ptr[j] = (ptr[j] * a) * gamma_data[j];
}
}
else
{
for (int j = 0; j < w; j++)
{
ptr[j] = ptr[j] * a;
}
}
}
}

if (dims == 3)
{
int w = bottom_top_blob.w;
int h = bottom_top_blob.h;
int channels = bottom_top_blob.c;
int size = w * h;

if (affine_size == w)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < channels; q++)
{
for (int i = 0; i < h; i++)
{
float* ptr = bottom_top_blob.channel(q).row(i);

float sqsum = 0.f;
for (int j = 0; j < w; j++)
{
sqsum += ptr[j] * ptr[j];
}
float rms = sqrtf(sqsum / w + eps);

float a = 1.f / rms;

if (affine)
{
for (int j = 0; j < w; j++)
{
ptr[j] = (ptr[j] * a) * gamma_data[j];
}
}
else
{
for (int j = 0; j < w; j++)
{
ptr[j] = ptr[j] * a;
}
}
}
}
}
else // if (affine_size == size)
{
#pragma omp parallel for num_threads(opt.num_threads)
for (int q = 0; q < channels; q++)
{
float* ptr = bottom_top_blob.channel(q);

float sqsum = 0.f;
for (int i = 0; i < size; i++)
{
sqsum += ptr[i] * ptr[i];
}
float rms = sqrtf(sqsum / size + eps);

float a = 1.f / rms;

if (affine)
{
for (int i = 0; i < size; i++)
{
ptr[i] = (ptr[i] * a) * gamma_data[i];
}
}
else
{
for (int i = 0; i < size; i++)
{
ptr[i] = ptr[i] * a;
}
}
}
}
}

return 0;
}

} // namespace ncnn

+ 43
- 0
src/layer/rmsnorm.h View File

@@ -0,0 +1,43 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef LAYER_RMSNORM_H
#define LAYER_RMSNORM_H

#include "layer.h"

namespace ncnn {

class RMSNorm : public Layer
{
public:
RMSNorm();

virtual int load_param(const ParamDict& pd);

virtual int load_model(const ModelBin& mb);

virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const;

public:
int affine_size;
float eps;
int affine;

Mat gamma_data;
};

} // namespace ncnn

#endif // LAYER_RMSNORM_H

+ 1
- 0
tests/CMakeLists.txt View File

@@ -141,6 +141,7 @@ ncnn_add_layer_test(ReLU)
ncnn_add_layer_test(Reorg)
ncnn_add_layer_test(Requantize)
ncnn_add_layer_test(Reshape)
ncnn_add_layer_test(RMSNorm)
ncnn_add_layer_test(RNN)
ncnn_add_layer_test(ROIPooling)
ncnn_add_layer_test(ROIAlign)


+ 121
- 0
tests/test_rmsnorm.cpp View File

@@ -0,0 +1,121 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "testutil.h"

static int test_rmsnorm(const ncnn::Mat& a, int affine_size, float eps, int affine)
{
ncnn::ParamDict pd;
pd.set(0, affine_size);
pd.set(1, eps);
pd.set(2, affine);

std::vector<ncnn::Mat> weights(1);
weights[0] = RandomMat(affine_size);

int ret = test_layer("RMSNorm", pd, weights, a);
if (ret != 0)
{
fprintf(stderr, "test_rmsnorm failed a.dims=%d a=(%d %d %d) affine_size=%d eps=%f affine=%d\n", a.dims, a.w, a.h, a.c, affine_size, eps, affine);
}

return ret;
}

static int test_rmsnorm_0()
{
return 0
|| test_rmsnorm(RandomMat(6, 4, 2), 6, 0.01f, 0)
|| test_rmsnorm(RandomMat(4, 5, 6), 4, 0.01f, 0)
|| test_rmsnorm(RandomMat(3, 3, 8), 3, 0.002f, 0)
|| test_rmsnorm(RandomMat(5, 6, 12), 5, 0.02f, 0)
|| test_rmsnorm(RandomMat(4, 7, 16), 4, 0.02f, 0)
|| test_rmsnorm(RandomMat(6, 7, 24), 6, 0.001f, 0)
|| test_rmsnorm(RandomMat(5, 8, 32), 5, 0.001f, 0)
|| test_rmsnorm(RandomMat(6, 4, 2), 6, 0.01f, 1)
|| test_rmsnorm(RandomMat(4, 5, 6), 4, 0.01f, 1)
|| test_rmsnorm(RandomMat(3, 3, 8), 3, 0.002f, 1)
|| test_rmsnorm(RandomMat(5, 6, 12), 5, 0.02f, 1)
|| test_rmsnorm(RandomMat(4, 7, 16), 4, 0.02f, 1)
|| test_rmsnorm(RandomMat(6, 7, 24), 6, 0.001f, 1)
|| test_rmsnorm(RandomMat(5, 8, 32), 5, 0.001f, 1);
}

static int test_rmsnorm_1()
{
return 0
|| test_rmsnorm(RandomMat(6, 4, 2), 24, 0.01f, 0)
|| test_rmsnorm(RandomMat(4, 5, 6), 20, 0.01f, 0)
|| test_rmsnorm(RandomMat(3, 3, 8), 9, 0.002f, 0)
|| test_rmsnorm(RandomMat(5, 6, 12), 30, 0.02f, 0)
|| test_rmsnorm(RandomMat(4, 7, 16), 28, 0.02f, 0)
|| test_rmsnorm(RandomMat(6, 7, 24), 42, 0.001f, 0)
|| test_rmsnorm(RandomMat(5, 8, 32), 40, 0.001f, 0)
|| test_rmsnorm(RandomMat(6, 4, 2), 24, 0.01f, 1)
|| test_rmsnorm(RandomMat(4, 5, 6), 20, 0.01f, 1)
|| test_rmsnorm(RandomMat(3, 3, 8), 9, 0.002f, 1)
|| test_rmsnorm(RandomMat(5, 6, 12), 30, 0.02f, 1)
|| test_rmsnorm(RandomMat(4, 7, 16), 28, 0.02f, 1)
|| test_rmsnorm(RandomMat(6, 7, 24), 42, 0.001f, 1)
|| test_rmsnorm(RandomMat(5, 8, 32), 40, 0.001f, 1);
}

static int test_rmsnorm_2()
{
return 0
|| test_rmsnorm(RandomMat(4, 2), 4, 0.01f, 0)
|| test_rmsnorm(RandomMat(5, 6), 5, 0.01f, 0)
|| test_rmsnorm(RandomMat(3, 8), 3, 0.002f, 0)
|| test_rmsnorm(RandomMat(6, 12), 6, 0.02f, 0)
|| test_rmsnorm(RandomMat(4, 16), 4, 0.02f, 0)
|| test_rmsnorm(RandomMat(7, 24), 7, 0.001f, 0)
|| test_rmsnorm(RandomMat(8, 32), 8, 0.001f, 0)
|| test_rmsnorm(RandomMat(4, 2), 4, 0.01f, 1)
|| test_rmsnorm(RandomMat(5, 6), 5, 0.01f, 1)
|| test_rmsnorm(RandomMat(3, 8), 3, 0.002f, 1)
|| test_rmsnorm(RandomMat(6, 12), 6, 0.02f, 1)
|| test_rmsnorm(RandomMat(4, 16), 4, 0.02f, 1)
|| test_rmsnorm(RandomMat(7, 24), 7, 0.001f, 1)
|| test_rmsnorm(RandomMat(8, 32), 8, 0.001f, 1);
}

static int test_rmsnorm_3()
{
return 0
|| test_rmsnorm(RandomMat(2), 2, 0.01f, 0)
|| test_rmsnorm(RandomMat(6), 6, 0.01f, 0)
|| test_rmsnorm(RandomMat(8), 8, 0.002f, 0)
|| test_rmsnorm(RandomMat(12), 12, 0.02f, 0)
|| test_rmsnorm(RandomMat(16), 16, 0.02f, 0)
|| test_rmsnorm(RandomMat(24), 24, 0.001f, 0)
|| test_rmsnorm(RandomMat(32), 32, 0.001f, 0)
|| test_rmsnorm(RandomMat(2), 2, 0.01f, 1)
|| test_rmsnorm(RandomMat(6), 6, 0.01f, 1)
|| test_rmsnorm(RandomMat(8), 8, 0.002f, 1)
|| test_rmsnorm(RandomMat(12), 12, 0.02f, 1)
|| test_rmsnorm(RandomMat(16), 16, 0.02f, 1)
|| test_rmsnorm(RandomMat(24), 24, 0.001f, 1)
|| test_rmsnorm(RandomMat(32), 32, 0.001f, 1);
}

int main()
{
SRAND(7767517);

return 0
|| test_rmsnorm_0()
|| test_rmsnorm_1()
|| test_rmsnorm_2()
|| test_rmsnorm_3();
}

+ 2
- 0
tools/pnnx/src/CMakeLists.txt View File

@@ -475,6 +475,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/F_prelu.cpp
pass_ncnn/F_relu.cpp
pass_ncnn/F_relu6.cpp
pass_ncnn/F_rms_norm.cpp
pass_ncnn/F_scaled_dot_product_attention.cpp
pass_ncnn/F_selu.cpp
pass_ncnn/F_sigmoid.cpp
@@ -541,6 +542,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/nn_ReplicationPad1d.cpp
pass_ncnn/nn_ReplicationPad2d.cpp
pass_ncnn/nn_ReplicationPad3d.cpp
pass_ncnn/nn_RMSNorm.cpp
pass_ncnn/nn_RNN.cpp
pass_ncnn/nn_SELU.cpp
pass_ncnn/nn_Sigmoid.cpp


+ 1
- 1
tools/pnnx/src/pass_level1/nn_RMSNorm.cpp View File

@@ -37,7 +37,7 @@ public:

op->params["normalized_shape"] = rmsn->namedInput("normalized_shape");
op->params["eps"] = rmsn->namedInput("eps");
op->params["elementwise_affine"] = mod.hasattr("weight") && mod.hasattr("bias");
op->params["elementwise_affine"] = mod.hasattr("weight");

if (mod.hasattr("weight"))
{


+ 65
- 0
tools/pnnx/src/pass_ncnn/F_rms_norm.cpp View File

@@ -0,0 +1,65 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class F_rms_norm : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
F.rms_norm op_0 1 1 input out weight=None normalized_shape=%normalized_shape eps=%eps
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "RMSNorm";
}

const char* name_str() const
{
return "rmsn";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
const std::vector<int>& normalized_shape = captured_params.at("normalized_shape").ai;
int affine_size = normalized_shape[0];
for (size_t i = 1; i < normalized_shape.size(); i++)
{
affine_size *= normalized_shape[i];
}

const float eps = captured_params.at("eps").type == 0 ? 0.f : captured_params.at("eps").f;

op->params["0"] = affine_size;
op->params["1"] = eps;
op->params["2"] = 0;
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_rms_norm, 20)

} // namespace ncnn

} // namespace pnnx

+ 70
- 0
tools/pnnx/src/pass_ncnn/nn_RMSNorm.cpp View File

@@ -0,0 +1,70 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class nn_RMSNorm : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.RMSNorm op_0 1 1 input out normalized_shape=%normalized_shape eps=%eps elementwise_affine=%elementwise_affine @weight
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "RMSNorm";
}

const char* name_str() const
{
return "rmsn";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
const std::vector<int>& normalized_shape = captured_params.at("normalized_shape").ai;
int affine_size = normalized_shape[0];
for (size_t i = 1; i < normalized_shape.size(); i++)
{
affine_size *= normalized_shape[i];
}

const float eps = captured_params.at("eps").type == 0 ? 0.f : captured_params.at("eps").f;

op->params["0"] = affine_size;
op->params["1"] = eps;
op->params["2"] = captured_params.at("elementwise_affine").b ? 1 : 0;

if (captured_params.at("elementwise_affine").b)
{
op->attrs["0"] = captured_attrs.at("op_0.weight");
}
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_RMSNorm, 20)

} // namespace ncnn

} // namespace pnnx

+ 2
- 0
tools/pnnx/tests/ncnn/CMakeLists.txt View File

@@ -53,6 +53,7 @@ pnnx_ncnn_add_test(F_pixel_unshuffle)
pnnx_ncnn_add_test(F_prelu)
pnnx_ncnn_add_test(F_relu)
pnnx_ncnn_add_test(F_relu6)
pnnx_ncnn_add_test(F_rms_norm)
pnnx_ncnn_add_test(F_selu)
pnnx_ncnn_add_test(F_sigmoid)
pnnx_ncnn_add_test(F_silu)
@@ -123,6 +124,7 @@ pnnx_ncnn_add_test(nn_ReLU6)
pnnx_ncnn_add_test(nn_ReplicationPad1d)
pnnx_ncnn_add_test(nn_ReplicationPad2d)
pnnx_ncnn_add_test(nn_ReplicationPad3d)
pnnx_ncnn_add_test(nn_RMSNorm)
pnnx_ncnn_add_test(nn_RNN)
pnnx_ncnn_add_test(nn_SELU)
pnnx_ncnn_add_test(nn_Sigmoid)


+ 3
- 3
tools/pnnx/tests/ncnn/test_F_layer_norm.py View File

@@ -37,8 +37,8 @@ def test():
net.eval()

torch.manual_seed(0)
x = torch.rand(12, 24)
y = torch.rand(3, 12, 16)
x = torch.rand(1, 12, 24)
y = torch.rand(1, 3, 12, 16)

a = net(x, y)

@@ -48,7 +48,7 @@ def test():

# torchscript to pnnx
import os
os.system("../../src/pnnx test_F_layer_norm.pt inputshape=[12,24],[3,12,16]")
os.system("../../src/pnnx test_F_layer_norm.pt inputshape=[1,12,24],[1,3,12,16]")

# ncnn inference
import test_F_layer_norm_ncnn


+ 68
- 0
tools/pnnx/tests/ncnn/test_F_rms_norm.py View File

@@ -0,0 +1,68 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.w3 = nn.Parameter(torch.rand(24))
self.w4 = nn.Parameter(torch.rand(12, 16))

def forward(self, x, y):
x = F.rms_norm(x, (24,), self.w3)

y = F.rms_norm(y, (16,), None)
z = F.rms_norm(y, (12,16), self.w4, eps=1e-3)
return x, y, z

def test():
if version.parse(torch.__version__) < version.parse('2.4'):
return True

net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(1, 12, 24)
y = torch.rand(1, 3, 12, 16)

a = net(x, y)

# export torchscript
mod = torch.jit.trace(net, (x, y))
mod.save("test_F_rms_norm.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_F_rms_norm.pt inputshape=[1,12,24],[1,3,12,16]")

# ncnn inference
import test_F_rms_norm_ncnn
b = test_F_rms_norm_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 3
- 3
tools/pnnx/tests/ncnn/test_nn_LayerNorm.py View File

@@ -36,8 +36,8 @@ def test():
net.eval()

torch.manual_seed(0)
x = torch.rand(24, 64)
y = torch.rand(12, 24, 64)
x = torch.rand(1, 24, 64)
y = torch.rand(1, 12, 24, 64)

a = net(x, y)

@@ -47,7 +47,7 @@ def test():

# torchscript to pnnx
import os
os.system("../../src/pnnx test_nn_LayerNorm.pt inputshape=[24,64],[12,24,64]")
os.system("../../src/pnnx test_nn_LayerNorm.pt inputshape=[1,24,64],[1,12,24,64]")

# ncnn inference
import test_nn_LayerNorm_ncnn


+ 68
- 0
tools/pnnx/tests/ncnn/test_nn_RMSNorm.py View File

@@ -0,0 +1,68 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
from packaging import version

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

self.rmsn_0 = nn.RMSNorm(64)
self.rmsn_0.weight = nn.Parameter(torch.rand(64))
self.rmsn_1 = nn.RMSNorm(normalized_shape=(24,64), eps=1e-2, elementwise_affine=False)

def forward(self, x, y):
x = self.rmsn_0(x)
y = self.rmsn_0(y)
z = self.rmsn_1(y)
return x, y, z

def test():
if version.parse(torch.__version__) < version.parse('2.4'):
return True

net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(1, 24, 64)
y = torch.rand(1, 12, 24, 64)

a = net(x, y)

# export torchscript
mod = torch.jit.trace(net, (x, y))
mod.save("test_nn_RMSNorm.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_nn_RMSNorm.pt inputshape=[1,24,64],[1,12,24,64]")

# ncnn inference
import test_nn_RMSNorm_ncnn
b = test_nn_RMSNorm_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

Loading…
Cancel
Save