diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index c15ba0973..58264dfd9 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -270,6 +270,8 @@ set(pnnx_pass_level2_SRCS pass_level2/torch_unbind.cpp pass_level2/torch_unsqueeze.cpp pass_level2/torch_var.cpp + pass_level2/torch_view_as_complex.cpp + pass_level2/torch_view_as_real.cpp pass_level2/torch_zeros.cpp pass_level2/torch_zeros_like.cpp pass_level2/torch_stft.cpp diff --git a/tools/pnnx/src/pass_level2/torch_view_as_complex.cpp b/tools/pnnx/src/pass_level2/torch_view_as_complex.cpp new file mode 100644 index 000000000..e00ff1371 --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_view_as_complex.cpp @@ -0,0 +1,40 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_view_as_complex : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::view_as_complex op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.view_as_complex"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_view_as_complex, 20) + +} // namespace pnnx \ No newline at end of file diff --git a/tools/pnnx/src/pass_level2/torch_view_as_real.cpp b/tools/pnnx/src/pass_level2/torch_view_as_real.cpp new file mode 100644 index 000000000..83327e01e --- /dev/null +++ b/tools/pnnx/src/pass_level2/torch_view_as_real.cpp @@ -0,0 +1,40 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "pass_level2.h" + +namespace pnnx { + +class torch_view_as_real : public GraphRewriterPass +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +3 2 +pnnx.Input input 0 1 input +aten::view_as_real op_0 1 1 input out +pnnx.Output output 1 0 out +)PNNXIR"; + } + + const char* type_str() const + { + return "torch.view_as_real"; + } +}; + +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_view_as_real, 20) + +} // namespace pnnx \ No newline at end of file diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index fab12342b..346ee0a95 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -237,6 +237,8 @@ pnnx_add_test(torch_topk) pnnx_add_test(torch_transpose) pnnx_add_test(torch_unbind) pnnx_add_test(torch_unsqueeze) +pnnx_add_test(torch_view_as_complex) +pnnx_add_test(torch_view_as_real) pnnx_add_test(torch_zeros) pnnx_add_test(torch_zeros_like) diff --git a/tools/pnnx/tests/test_torch_view_as_complex.py b/tools/pnnx/tests/test_torch_view_as_complex.py new file mode 100644 index 000000000..c2cedc537 --- /dev/null +++ b/tools/pnnx/tests/test_torch_view_as_complex.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.view_as_complex(x) + y = torch.view_as_complex(y) + z = torch.view_as_complex(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 2) + y = torch.rand(1, 5, 9, 2) + z = torch.rand(14, 8, 5, 9, 2) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_view_as_complex.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_view_as_complex.pt inputshape=[1,3,2],[1,5,9,2],[14,8,5,9,2]") + + # pnnx inference + import test_torch_view_as_complex_pnnx + b = test_torch_view_as_complex_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) \ No newline at end of file diff --git a/tools/pnnx/tests/test_torch_view_as_real.py b/tools/pnnx/tests/test_torch_view_as_real.py new file mode 100644 index 000000000..06bbe7de9 --- /dev/null +++ b/tools/pnnx/tests/test_torch_view_as_real.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.view_as_real(x) + y = torch.view_as_real(y) + z = torch.view_as_real(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16,dtype=torch.complex64) + y = torch.rand(1, 5, 9, 11,dtype=torch.complex64) + z = torch.rand(14, 8, 5, 9, 10,dtype=torch.complex64) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_view_as_real.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_view_as_real.pt inputshape=[1,3,16]c64,[1,5,9,11]c64,[14,8,5,9,10]c64") + + # pnnx inference + import test_torch_view_as_real_pnnx + b = test_torch_view_as_real_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) \ No newline at end of file