From 242e775d2139ef3a53e8f6ec56d0d04f25009901 Mon Sep 17 00:00:00 2001 From: nihui Date: Thu, 9 Feb 2023 19:28:18 +0800 Subject: [PATCH] pnnx convert torch log10, pow 2 as square (#4518) --- docs/developer-guide/operators.md | 1 + src/layer/arm/unaryop_arm.cpp | 20 +++++ src/layer/arm/unaryop_arm_asimdhp.cpp | 19 +++++ src/layer/loongarch/unaryop_loongarch.cpp | 17 ++++ src/layer/mips/unaryop_mips.cpp | 17 ++++ src/layer/riscv/unaryop_riscv.cpp | 22 +++++ src/layer/unaryop.cpp | 11 +++ src/layer/unaryop.h | 3 +- src/layer/vulkan/shader/unaryop.comp | 1 + src/layer/vulkan/shader/unaryop_pack4.comp | 1 + src/layer/vulkan/shader/unaryop_pack8.comp | 5 ++ src/layer/x86/unaryop_x86.cpp | 29 +++++++ tests/test_unaryop.cpp | 4 +- tools/pnnx/src/CMakeLists.txt | 1 + tools/pnnx/src/ir.cpp | 2 + .../pnnx/src/pass_level3/fuse_expression.cpp | 3 + tools/pnnx/src/pass_level5.cpp | 3 + .../src/pass_level5/attribute_unpooling.cpp | 80 +++++++++++++++++++ .../src/pass_level5/attribute_unpooling.h | 21 +++++ .../pnnx/src/pass_level5/eval_expression.cpp | 6 ++ tools/pnnx/src/pass_ncnn/Tensor_reshape.cpp | 2 +- tools/pnnx/src/pass_ncnn/Tensor_view.cpp | 2 +- .../pnnx/src/pass_ncnn/expand_expression.cpp | 10 +++ .../pnnx/src/pass_ncnn/solve_batch_index.cpp | 10 +++ tools/pnnx/tests/CMakeLists.txt | 1 + tools/pnnx/tests/ncnn/CMakeLists.txt | 1 + tools/pnnx/tests/ncnn/test_torch_log10.py | 61 ++++++++++++++ tools/pnnx/tests/test_torch_log10.py | 61 ++++++++++++++ 28 files changed, 409 insertions(+), 5 deletions(-) create mode 100644 tools/pnnx/src/pass_level5/attribute_unpooling.cpp create mode 100644 tools/pnnx/src/pass_level5/attribute_unpooling.h create mode 100644 tools/pnnx/tests/ncnn/test_torch_log10.py create mode 100644 tools/pnnx/tests/test_torch_log10.py diff --git a/docs/developer-guide/operators.md b/docs/developer-guide/operators.md index ddafa7caa..4647fb08f 100644 --- a/docs/developer-guide/operators.md +++ b/docs/developer-guide/operators.md @@ -1759,3 +1759,4 @@ Operation type: - 14 = ATAN - 15 = RECIPROCAL - 16 = TANH +- 17 = LOG10 diff --git a/src/layer/arm/unaryop_arm.cpp b/src/layer/arm/unaryop_arm.cpp index c2c7de31c..eb449a7e9 100644 --- a/src/layer/arm/unaryop_arm.cpp +++ b/src/layer/arm/unaryop_arm.cpp @@ -360,6 +360,20 @@ struct unary_op_tanh #endif // __ARM_NEON }; +struct unary_op_log10 +{ + float func(const float& x) const + { + return (float)log10(x); + } +#if __ARM_NEON + float32x4_t func_pack4(const float32x4_t& x) const + { + return vmulq_f32(log_ps(x), vdupq_n_f32(0.434294481903)); + } +#endif // __ARM_NEON +}; + } // namespace UnaryOp_arm_functor int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -429,6 +443,9 @@ int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const if (op_type == Operation_TANH) return unary_op_inplace(bottom_top_blob, opt); + if (op_type == Operation_LOG10) + return unary_op_inplace(bottom_top_blob, opt); + return 0; } @@ -556,6 +573,9 @@ int UnaryOp_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) if (op_type == Operation_TANH) return unary_op_inplace_bf16s(bottom_top_blob, opt); + if (op_type == Operation_LOG10) + return unary_op_inplace_bf16s(bottom_top_blob, opt); + return 0; } #endif // NCNN_BF16 diff --git a/src/layer/arm/unaryop_arm_asimdhp.cpp b/src/layer/arm/unaryop_arm_asimdhp.cpp index 2cd816290..b6bd0c601 100644 --- a/src/layer/arm/unaryop_arm_asimdhp.cpp +++ b/src/layer/arm/unaryop_arm_asimdhp.cpp @@ -436,6 +436,22 @@ struct unary_op_tanh_fp16s } }; +struct unary_op_log10_fp16s +{ + __fp16 func(const __fp16& x) const + { + return (__fp16)log10(x); + } + float16x4_t func_pack4(const float16x4_t& x) const + { + return vmul_f16(log_ps(x), vdup_n_f16(0.434294481903)); + } + float16x8_t func_pack8(const float16x8_t& x) const + { + return vmulq_f16(log_ps(x), vdupq_n_f16(0.434294481903)); + } +}; + } // namespace UnaryOp_arm_functor int UnaryOp_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const @@ -493,6 +509,9 @@ int UnaryOp_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) if (op_type == Operation_TANH) return unary_op_inplace_fp16s(bottom_top_blob, opt); + if (op_type == Operation_LOG10) + return unary_op_inplace_fp16s(bottom_top_blob, opt); + return 0; } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC diff --git a/src/layer/loongarch/unaryop_loongarch.cpp b/src/layer/loongarch/unaryop_loongarch.cpp index 892c4dc42..364d4a15b 100644 --- a/src/layer/loongarch/unaryop_loongarch.cpp +++ b/src/layer/loongarch/unaryop_loongarch.cpp @@ -364,6 +364,20 @@ struct unary_op_tanh #endif // __loongarch_sx }; +struct unary_op_log10 +{ + float func(const float& x) const + { + return (float)log10(x); + } +#if __loongarch_sx + __m128 func_pack4(const __m128& x) const + { + return __lsx_vfmul_s(log_ps(x), __lsx_vreplfr2vr_s(0.434294481903)); + } +#endif // __loongarch_sx +}; + } // namespace UnaryOp_loongarch_functor int UnaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -421,6 +435,9 @@ int UnaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt) if (op_type == Operation_TANH) return unary_op_inplace(bottom_top_blob, opt); + if (op_type == Operation_LOG10) + return unary_op_inplace(bottom_top_blob, opt); + return 0; } diff --git a/src/layer/mips/unaryop_mips.cpp b/src/layer/mips/unaryop_mips.cpp index c3d06adee..6bc6a7933 100644 --- a/src/layer/mips/unaryop_mips.cpp +++ b/src/layer/mips/unaryop_mips.cpp @@ -374,6 +374,20 @@ struct unary_op_tanh #endif // __mips_msa }; +struct unary_op_log10 +{ + float func(const float& x) const + { + return (float)log10(x); + } +#if __mips_msa + v4f32 func_pack4(const v4f32& x) const + { + return __msa_fmul_w(log_ps(x), __msa_fill_w_f32(0.434294481903)); + } +#endif // __mips_msa +}; + } // namespace UnaryOp_mips_functor int UnaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -431,6 +445,9 @@ int UnaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const if (op_type == Operation_TANH) return unary_op_inplace(bottom_top_blob, opt); + if (op_type == Operation_LOG10) + return unary_op_inplace(bottom_top_blob, opt); + return 0; } diff --git a/src/layer/riscv/unaryop_riscv.cpp b/src/layer/riscv/unaryop_riscv.cpp index e5eb80151..40d583e0d 100644 --- a/src/layer/riscv/unaryop_riscv.cpp +++ b/src/layer/riscv/unaryop_riscv.cpp @@ -245,6 +245,14 @@ struct unary_op_tanh } }; +struct unary_op_log10 +{ + vfloat32m8_t operator()(const vfloat32m8_t& x, const size_t& vl) const + { + return vfmul_vf_f32m8(log_ps(x, vl), 0.434294481903, vl); + } +}; + } // namespace UnaryOp_riscv_functor #endif // __riscv_vector @@ -311,6 +319,9 @@ int UnaryOp_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons if (op_type == Operation_TANH) return unary_op_inplace(bottom_top_blob, opt); + if (op_type == Operation_LOG10) + return unary_op_inplace(bottom_top_blob, opt); + return 0; #else // __riscv_vector return UnaryOp::forward_inplace(bottom_top_blob, opt); @@ -528,6 +539,14 @@ struct unary_op_tanh_fp16s } }; +struct unary_op_log10_fp16s +{ + vfloat16m8_t operator()(const vfloat16m8_t& x, const size_t& vl) const + { + return vfmul_vf_f16m8(log_ps(x, vl), 0.434294481903, vl); + } +}; + } // namespace UnaryOp_riscv_functor int UnaryOp_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const @@ -585,6 +604,9 @@ int UnaryOp_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt if (op_type == Operation_TANH) return unary_op_inplace_fp16s(bottom_top_blob, opt); + if (op_type == Operation_LOG10) + return unary_op_inplace_fp16s(bottom_top_blob, opt); + return 0; } #endif // __riscv_vector && __riscv_zfh diff --git a/src/layer/unaryop.cpp b/src/layer/unaryop.cpp index 412c3e18b..d3838e345 100644 --- a/src/layer/unaryop.cpp +++ b/src/layer/unaryop.cpp @@ -183,6 +183,14 @@ struct unary_op_tanh } }; +struct unary_op_log10 +{ + float operator()(const float& x) const + { + return (float)log10(x); + } +}; + int UnaryOp::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { if (op_type == Operation_ABS) @@ -236,6 +244,9 @@ int UnaryOp::forward_inplace(Mat& bottom_top_blob, const Option& opt) const if (op_type == Operation_TANH) return unary_op_inplace(bottom_top_blob, opt); + if (op_type == Operation_LOG10) + return unary_op_inplace(bottom_top_blob, opt); + return 0; } diff --git a/src/layer/unaryop.h b/src/layer/unaryop.h index bdbc80eab..9d38cdb86 100644 --- a/src/layer/unaryop.h +++ b/src/layer/unaryop.h @@ -46,7 +46,8 @@ public: Operation_ACOS = 13, Operation_ATAN = 14, Operation_RECIPROCAL = 15, - Operation_TANH = 16 + Operation_TANH = 16, + Operation_LOG10 = 17 }; public: diff --git a/src/layer/vulkan/shader/unaryop.comp b/src/layer/vulkan/shader/unaryop.comp index 9bee389c6..66e78a30a 100644 --- a/src/layer/vulkan/shader/unaryop.comp +++ b/src/layer/vulkan/shader/unaryop.comp @@ -86,6 +86,7 @@ void main() #else if (op_type == 16) res = tanh(v); #endif + if (op_type == 17) res = log(v) * afp(0.434294481903); #if NCNN_image_shader image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res); diff --git a/src/layer/vulkan/shader/unaryop_pack4.comp b/src/layer/vulkan/shader/unaryop_pack4.comp index f04c81618..4525ba08c 100644 --- a/src/layer/vulkan/shader/unaryop_pack4.comp +++ b/src/layer/vulkan/shader/unaryop_pack4.comp @@ -86,6 +86,7 @@ void main() #else if (op_type == 16) res = tanh(v); #endif + if (op_type == 17) res = log(v) * afp(0.434294481903); #if NCNN_image_shader image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); diff --git a/src/layer/vulkan/shader/unaryop_pack8.comp b/src/layer/vulkan/shader/unaryop_pack8.comp index ebca872f7..47bdfbde7 100644 --- a/src/layer/vulkan/shader/unaryop_pack8.comp +++ b/src/layer/vulkan/shader/unaryop_pack8.comp @@ -156,6 +156,11 @@ void main() res[1] = tanh(v[1]); #endif } + if (op_type == 17) + { + res[0] = log(v[0]) * afp(0.434294481903); + res[1] = log(v[1]) * afp(0.434294481903); + } #if NCNN_image_shader image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp index d7e052c29..faba25dbd 100644 --- a/src/layer/x86/unaryop_x86.cpp +++ b/src/layer/x86/unaryop_x86.cpp @@ -686,6 +686,32 @@ struct unary_op_tanh #endif // __SSE2__ }; +struct unary_op_log10 +{ + float func(const float& x) const + { + return (float)log10(x); + } +#if __SSE2__ + __m128 func_pack4(const __m128& x) const + { + return _mm_mul_ps(log_ps(x), _mm_set1_ps(0.434294481903)); + } +#if __AVX__ + __m256 func_pack8(const __m256& x) const + { + return _mm256_mul_ps(log256_ps(x), _mm256_set1_ps(0.434294481903)); + } +#if __AVX512F__ + __m512 func_pack16(const __m512& x) const + { + return _mm512_mul_ps(log512_ps(x), _mm512_set1_ps(0.434294481903)); + } +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +}; + } // namespace UnaryOp_x86_functor int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const @@ -742,6 +768,9 @@ int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const if (op_type == Operation_TANH) return unary_op_inplace(bottom_top_blob, opt); + if (op_type == Operation_LOG10) + return unary_op_inplace(bottom_top_blob, opt); + return 0; } diff --git a/tests/test_unaryop.cpp b/tests/test_unaryop.cpp index 473ab2cab..bb25b06b7 100644 --- a/tests/test_unaryop.cpp +++ b/tests/test_unaryop.cpp @@ -15,7 +15,7 @@ #include "layer/unaryop.h" #include "testutil.h" -#define OP_TYPE_MAX 17 +#define OP_TYPE_MAX 18 static int op_type = 0; @@ -30,7 +30,7 @@ static int test_unaryop(const ncnn::Mat& _a) a[i] *= 1000; } } - if (op_type == 5 || op_type == 6 || op_type == 8) + if (op_type == 5 || op_type == 6 || op_type == 8 || op_type == 17) { // value must be positive for sqrt rsqrt log Randomize(a, 0.001f, 2.f); diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index 8e0c72ee6..5a33967b9 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -298,6 +298,7 @@ set(pnnx_pass_level4_SRCS ) set(pnnx_pass_level5_SRCS + pass_level5/attribute_unpooling.cpp pass_level5/eliminate_dropout.cpp pass_level5/eliminate_identity_operator.cpp pass_level5/eliminate_maxpool_indices.cpp diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 65c5ed018..2a8c5d54f 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -1074,6 +1074,7 @@ static std::string expand_expression(const Operator* op) || t == "exp" || t == "floor" || t == "log" + || t == "log10" || t == "neg" || t == "reciprocal" || t == "rsqrt" @@ -1101,6 +1102,7 @@ static std::string expand_expression(const Operator* op) if (t == "exp") unaryop = "torch.exp"; if (t == "floor") unaryop = "torch.floor"; if (t == "log") unaryop = "torch.log"; + if (t == "log10") unaryop = "torch.log10"; if (t == "neg") unaryop = "torch.neg"; if (t == "reciprocal") unaryop = "torch.reciprocal"; if (t == "rsqrt") unaryop = "torch.rsqrt"; diff --git a/tools/pnnx/src/pass_level3/fuse_expression.cpp b/tools/pnnx/src/pass_level3/fuse_expression.cpp index 0866e1301..698bec333 100644 --- a/tools/pnnx/src/pass_level3/fuse_expression.cpp +++ b/tools/pnnx/src/pass_level3/fuse_expression.cpp @@ -78,6 +78,7 @@ static bool operand_maybe_tensor(const Operand* operand) || op->type == "aten::exp" || op->type == "aten::floor" || op->type == "aten::log" + || op->type == "aten::log10" || op->type == "aten::neg" || op->type == "aten::reciprocal" || op->type == "aten::rsqrt" @@ -286,6 +287,7 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s || op->type == "aten::exp" || op->type == "aten::floor" || op->type == "aten::log" + || op->type == "aten::log10" || op->type == "aten::neg" || op->type == "aten::reciprocal" || op->type == "aten::rsqrt" @@ -464,6 +466,7 @@ void fuse_expression(Graph& graph, const std::set& foldable_constan || op->type == "aten::floor" || op->type == "aten::floor_divide" || op->type == "aten::log" + || op->type == "aten::log10" || op->type == "aten::mul" || op->type == "aten::neg" || op->type == "aten::pow" diff --git a/tools/pnnx/src/pass_level5.cpp b/tools/pnnx/src/pass_level5.cpp index de033b1ba..72030dd47 100644 --- a/tools/pnnx/src/pass_level5.cpp +++ b/tools/pnnx/src/pass_level5.cpp @@ -14,6 +14,7 @@ #include "pass_level5.h" +#include "pass_level5/attribute_unpooling.h" #include "pass_level5/fold_constants.h" #include "pass_level5/eliminate_dropout.h" #include "pass_level5/eliminate_identity_operator.h" @@ -81,6 +82,8 @@ void pass_level5(Graph& g, const std::set& foldable_constants, cons fuse_slice_copy(g); + attribute_unpooling(g); + fuse_static_batchnorm(g); fuse_static_groupnorm(g); fuse_static_instancenorm(g); diff --git a/tools/pnnx/src/pass_level5/attribute_unpooling.cpp b/tools/pnnx/src/pass_level5/attribute_unpooling.cpp new file mode 100644 index 000000000..76e7f7622 --- /dev/null +++ b/tools/pnnx/src/pass_level5/attribute_unpooling.cpp @@ -0,0 +1,80 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "attribute_unpooling.h" + +#include + +namespace pnnx { + +void attribute_unpooling(Graph& graph) +{ + while (1) + { + bool matched = false; + + for (size_t i = 0; i < graph.ops.size(); i++) + { + Operator* op = graph.ops[i]; + + if (op->type != "pnnx.Attribute") + continue; + + Operand* attr = op->outputs[0]; + + if (attr->consumers.size() < 2) + continue; + + // multiple modules share same weight + matched = true; + + for (int i = 1; i < (int)attr->consumers.size(); i++) + { + Operator* x = attr->consumers[i]; + + Operator* op2 = graph.new_operator_after("pnnx.Attribute", op->name + "_" + std::to_string(i), op); + + op2->inputnames = op->inputnames; + op2->params = op->params; + op2->attrs = op->attrs; + + Operand* attr2 = graph.new_operand(attr->name + "_" + std::to_string(i)); + + attr2->type = attr->type; + attr2->shape = attr->shape; + attr2->params = attr->params; + + op2->outputs.push_back(attr2); + + attr2->producer = op2; + attr2->consumers.push_back(x); + + for (size_t j = 0; j < x->inputs.size(); j++) + { + if (x->inputs[j] == attr) + x->inputs[j] = attr2; + } + } + + attr->consumers.resize(1); + + break; + } + + if (!matched) + break; + } +} + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/attribute_unpooling.h b/tools/pnnx/src/pass_level5/attribute_unpooling.h new file mode 100644 index 000000000..333c709be --- /dev/null +++ b/tools/pnnx/src/pass_level5/attribute_unpooling.h @@ -0,0 +1,21 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +void attribute_unpooling(Graph& g); + +} // namespace pnnx diff --git a/tools/pnnx/src/pass_level5/eval_expression.cpp b/tools/pnnx/src/pass_level5/eval_expression.cpp index 7326f1b32..47868cdef 100644 --- a/tools/pnnx/src/pass_level5/eval_expression.cpp +++ b/tools/pnnx/src/pass_level5/eval_expression.cpp @@ -153,6 +153,7 @@ static std::string eval_expression(const Operator* op) || t == "exp" || t == "floor" || t == "log" + || t == "log10" || t == "neg" || t == "reciprocal" || t == "rsqrt" @@ -242,6 +243,11 @@ static std::string eval_expression(const Operator* op) float r = log(af); exprstack.push(std::to_string(r)); } + if (t == "log10") + { + float r = log10(af); + exprstack.push(std::to_string(r)); + } if (t == "neg") { float r = -af; diff --git a/tools/pnnx/src/pass_ncnn/Tensor_reshape.cpp b/tools/pnnx/src/pass_ncnn/Tensor_reshape.cpp index ee3b74e6a..da99e7cee 100644 --- a/tools/pnnx/src/pass_ncnn/Tensor_reshape.cpp +++ b/tools/pnnx/src/pass_ncnn/Tensor_reshape.cpp @@ -45,7 +45,7 @@ pnnx.Output output 1 0 out { const std::vector& shape = captured_params.at("shape").ai; - const int batch_index = op->inputs[0]->params["__batch_index"].i; + const int batch_index = op->outputs[0]->params["__batch_index"].i; if (batch_index != 0 && batch_index != 233) { diff --git a/tools/pnnx/src/pass_ncnn/Tensor_view.cpp b/tools/pnnx/src/pass_ncnn/Tensor_view.cpp index 19c704b91..7c41b0ae5 100644 --- a/tools/pnnx/src/pass_ncnn/Tensor_view.cpp +++ b/tools/pnnx/src/pass_ncnn/Tensor_view.cpp @@ -45,7 +45,7 @@ pnnx.Output output 1 0 out { const std::vector& shape = captured_params.at("shape").ai; - const int batch_index = op->inputs[0]->params["__batch_index"].i; + const int batch_index = op->outputs[0]->params["__batch_index"].i; if (batch_index != 0 && batch_index != 233) { diff --git a/tools/pnnx/src/pass_ncnn/expand_expression.cpp b/tools/pnnx/src/pass_ncnn/expand_expression.cpp index cb716b914..b9492e59f 100644 --- a/tools/pnnx/src/pass_ncnn/expand_expression.cpp +++ b/tools/pnnx/src/pass_ncnn/expand_expression.cpp @@ -128,6 +128,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx || t == "exp" || t == "floor" || t == "log" + || t == "log10" || t == "neg" || t == "reciprocal" || t == "rsqrt" @@ -154,6 +155,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx if (t == "exp") op_unary->params["0"] = 7; if (t == "floor") op_unary->params["0"] = 2; if (t == "log") op_unary->params["0"] = 8; + if (t == "log10") op_unary->params["0"] = 17; if (t == "neg") op_unary->params["0"] = 1; if (t == "reciprocal") op_unary->params["0"] = 15; if (t == "rsqrt") op_unary->params["0"] = 6; @@ -222,6 +224,14 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx op_binary->params["1"] = 1; // with_scalar op_binary->params["2"] = std::stof(b); + if (t == "pow" && std::stof(b) == 2) + { + // replace pow 2 with square + op_binary->type = "UnaryOp"; + op_binary->params.clear(); + op_binary->params["0"] = 4; + } + Operand* op_binary_out = graph.new_operand(op->name + "_" + r); op_binary_out->producer = op_binary; diff --git a/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp b/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp index 73e8e08eb..8f2fd2529 100644 --- a/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp +++ b/tools/pnnx/src/pass_ncnn/solve_batch_index.cpp @@ -334,6 +334,16 @@ void solve_batch_index(Graph& graph) } } + // always treat 1-dim tensor as no batch axis + for (Operand* r : graph.operands) + { + if (r->shape.size() == 1) + { + fprintf(stderr, "force batch axis 233 for operand %s\n", r->name.c_str()); + r->params["__batch_index"] = 233; + } + } + // fallback axis 233 for unknown for (Operand* r : graph.operands) { diff --git a/tools/pnnx/tests/CMakeLists.txt b/tools/pnnx/tests/CMakeLists.txt index 9d2177790..5c7450b2e 100644 --- a/tools/pnnx/tests/CMakeLists.txt +++ b/tools/pnnx/tests/CMakeLists.txt @@ -256,6 +256,7 @@ pnnx_add_test(torch_exp) pnnx_add_test(torch_floor) pnnx_add_test(torch_imag) pnnx_add_test(torch_log) +pnnx_add_test(torch_log10) pnnx_add_test(torch_neg) pnnx_add_test(torch_pow) pnnx_add_test(torch_real) diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt index d2f918ccb..be5c28c71 100644 --- a/tools/pnnx/tests/ncnn/CMakeLists.txt +++ b/tools/pnnx/tests/ncnn/CMakeLists.txt @@ -168,6 +168,7 @@ pnnx_ncnn_add_test(torch_cos) pnnx_ncnn_add_test(torch_exp) pnnx_ncnn_add_test(torch_floor) pnnx_ncnn_add_test(torch_log) +pnnx_ncnn_add_test(torch_log10) pnnx_ncnn_add_test(torch_neg) pnnx_ncnn_add_test(torch_pow) pnnx_ncnn_add_test(torch_reciprocal) diff --git a/tools/pnnx/tests/ncnn/test_torch_log10.py b/tools/pnnx/tests/ncnn/test_torch_log10.py new file mode 100644 index 000000000..20b8f0939 --- /dev/null +++ b/tools/pnnx/tests/ncnn/test_torch_log10.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.log10(x) + y = torch.log10(y) + z = torch.log10(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(3, 16) + y = torch.rand(5, 9, 11) + z = torch.rand(8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_log10.pt") + + # torchscript to pnnx + import os + os.system("../../src/pnnx test_torch_log10.pt inputshape=[3,16],[5,9,11],[8,5,9,10]") + + # ncnn inference + import test_torch_log10_ncnn + b = test_torch_log10_ncnn.test_inference() + + for a0, b0 in zip(a, b): + if not torch.allclose(a0, b0, 1e-4, 1e-4): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1) diff --git a/tools/pnnx/tests/test_torch_log10.py b/tools/pnnx/tests/test_torch_log10.py new file mode 100644 index 000000000..5fc2984b6 --- /dev/null +++ b/tools/pnnx/tests/test_torch_log10.py @@ -0,0 +1,61 @@ +# Tencent is pleased to support the open source community by making ncnn available. +# +# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software distributed +# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Model(nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, y, z): + x = torch.log10(x) + y = torch.log10(y) + z = torch.log10(z) + return x, y, z + +def test(): + net = Model() + net.eval() + + torch.manual_seed(0) + x = torch.rand(1, 3, 16) + y = torch.rand(1, 5, 9, 11) + z = torch.rand(14, 8, 5, 9, 10) + + a = net(x, y, z) + + # export torchscript + mod = torch.jit.trace(net, (x, y, z)) + mod.save("test_torch_log10.pt") + + # torchscript to pnnx + import os + os.system("../src/pnnx test_torch_log10.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") + + # pnnx inference + import test_torch_log10_pnnx + b = test_torch_log10_pnnx.test_inference() + + for a0, b0 in zip(a, b): + if not torch.equal(a0, b0): + return False + return True + +if __name__ == "__main__": + if test(): + exit(0) + else: + exit(1)