Browse Source

pnnx convert torch log10, pow 2 as square (#4518)

tags/20230223
nihui GitHub 3 years ago
parent
commit
242e775d21
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 409 additions and 5 deletions
  1. +1
    -0
      docs/developer-guide/operators.md
  2. +20
    -0
      src/layer/arm/unaryop_arm.cpp
  3. +19
    -0
      src/layer/arm/unaryop_arm_asimdhp.cpp
  4. +17
    -0
      src/layer/loongarch/unaryop_loongarch.cpp
  5. +17
    -0
      src/layer/mips/unaryop_mips.cpp
  6. +22
    -0
      src/layer/riscv/unaryop_riscv.cpp
  7. +11
    -0
      src/layer/unaryop.cpp
  8. +2
    -1
      src/layer/unaryop.h
  9. +1
    -0
      src/layer/vulkan/shader/unaryop.comp
  10. +1
    -0
      src/layer/vulkan/shader/unaryop_pack4.comp
  11. +5
    -0
      src/layer/vulkan/shader/unaryop_pack8.comp
  12. +29
    -0
      src/layer/x86/unaryop_x86.cpp
  13. +2
    -2
      tests/test_unaryop.cpp
  14. +1
    -0
      tools/pnnx/src/CMakeLists.txt
  15. +2
    -0
      tools/pnnx/src/ir.cpp
  16. +3
    -0
      tools/pnnx/src/pass_level3/fuse_expression.cpp
  17. +3
    -0
      tools/pnnx/src/pass_level5.cpp
  18. +80
    -0
      tools/pnnx/src/pass_level5/attribute_unpooling.cpp
  19. +21
    -0
      tools/pnnx/src/pass_level5/attribute_unpooling.h
  20. +6
    -0
      tools/pnnx/src/pass_level5/eval_expression.cpp
  21. +1
    -1
      tools/pnnx/src/pass_ncnn/Tensor_reshape.cpp
  22. +1
    -1
      tools/pnnx/src/pass_ncnn/Tensor_view.cpp
  23. +10
    -0
      tools/pnnx/src/pass_ncnn/expand_expression.cpp
  24. +10
    -0
      tools/pnnx/src/pass_ncnn/solve_batch_index.cpp
  25. +1
    -0
      tools/pnnx/tests/CMakeLists.txt
  26. +1
    -0
      tools/pnnx/tests/ncnn/CMakeLists.txt
  27. +61
    -0
      tools/pnnx/tests/ncnn/test_torch_log10.py
  28. +61
    -0
      tools/pnnx/tests/test_torch_log10.py

+ 1
- 0
docs/developer-guide/operators.md View File

@@ -1759,3 +1759,4 @@ Operation type:
- 14 = ATAN
- 15 = RECIPROCAL
- 16 = TANH
- 17 = LOG10

+ 20
- 0
src/layer/arm/unaryop_arm.cpp View File

@@ -360,6 +360,20 @@ struct unary_op_tanh
#endif // __ARM_NEON
};

struct unary_op_log10
{
float func(const float& x) const
{
return (float)log10(x);
}
#if __ARM_NEON
float32x4_t func_pack4(const float32x4_t& x) const
{
return vmulq_f32(log_ps(x), vdupq_n_f32(0.434294481903));
}
#endif // __ARM_NEON
};

} // namespace UnaryOp_arm_functor

int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
@@ -429,6 +443,9 @@ int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
if (op_type == Operation_TANH)
return unary_op_inplace<unary_op_tanh>(bottom_top_blob, opt);

if (op_type == Operation_LOG10)
return unary_op_inplace<unary_op_log10>(bottom_top_blob, opt);

return 0;
}

@@ -556,6 +573,9 @@ int UnaryOp_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt)
if (op_type == Operation_TANH)
return unary_op_inplace_bf16s<unary_op_tanh>(bottom_top_blob, opt);

if (op_type == Operation_LOG10)
return unary_op_inplace_bf16s<unary_op_log10>(bottom_top_blob, opt);

return 0;
}
#endif // NCNN_BF16


+ 19
- 0
src/layer/arm/unaryop_arm_asimdhp.cpp View File

@@ -436,6 +436,22 @@ struct unary_op_tanh_fp16s
}
};

struct unary_op_log10_fp16s
{
__fp16 func(const __fp16& x) const
{
return (__fp16)log10(x);
}
float16x4_t func_pack4(const float16x4_t& x) const
{
return vmul_f16(log_ps(x), vdup_n_f16(0.434294481903));
}
float16x8_t func_pack8(const float16x8_t& x) const
{
return vmulq_f16(log_ps(x), vdupq_n_f16(0.434294481903));
}
};

} // namespace UnaryOp_arm_functor

int UnaryOp_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
@@ -493,6 +509,9 @@ int UnaryOp_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt)
if (op_type == Operation_TANH)
return unary_op_inplace_fp16s<unary_op_tanh_fp16s>(bottom_top_blob, opt);

if (op_type == Operation_LOG10)
return unary_op_inplace_fp16s<unary_op_log10_fp16s>(bottom_top_blob, opt);

return 0;
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC


+ 17
- 0
src/layer/loongarch/unaryop_loongarch.cpp View File

@@ -364,6 +364,20 @@ struct unary_op_tanh
#endif // __loongarch_sx
};

struct unary_op_log10
{
float func(const float& x) const
{
return (float)log10(x);
}
#if __loongarch_sx
__m128 func_pack4(const __m128& x) const
{
return __lsx_vfmul_s(log_ps(x), __lsx_vreplfr2vr_s(0.434294481903));
}
#endif // __loongarch_sx
};

} // namespace UnaryOp_loongarch_functor

int UnaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
@@ -421,6 +435,9 @@ int UnaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt)
if (op_type == Operation_TANH)
return unary_op_inplace<unary_op_tanh>(bottom_top_blob, opt);

if (op_type == Operation_LOG10)
return unary_op_inplace<unary_op_log10>(bottom_top_blob, opt);

return 0;
}



+ 17
- 0
src/layer/mips/unaryop_mips.cpp View File

@@ -374,6 +374,20 @@ struct unary_op_tanh
#endif // __mips_msa
};

struct unary_op_log10
{
float func(const float& x) const
{
return (float)log10(x);
}
#if __mips_msa
v4f32 func_pack4(const v4f32& x) const
{
return __msa_fmul_w(log_ps(x), __msa_fill_w_f32(0.434294481903));
}
#endif // __mips_msa
};

} // namespace UnaryOp_mips_functor

int UnaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
@@ -431,6 +445,9 @@ int UnaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
if (op_type == Operation_TANH)
return unary_op_inplace<unary_op_tanh>(bottom_top_blob, opt);

if (op_type == Operation_LOG10)
return unary_op_inplace<unary_op_log10>(bottom_top_blob, opt);

return 0;
}



+ 22
- 0
src/layer/riscv/unaryop_riscv.cpp View File

@@ -245,6 +245,14 @@ struct unary_op_tanh
}
};

struct unary_op_log10
{
vfloat32m8_t operator()(const vfloat32m8_t& x, const size_t& vl) const
{
return vfmul_vf_f32m8(log_ps(x, vl), 0.434294481903, vl);
}
};

} // namespace UnaryOp_riscv_functor
#endif // __riscv_vector

@@ -311,6 +319,9 @@ int UnaryOp_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons
if (op_type == Operation_TANH)
return unary_op_inplace<unary_op_tanh>(bottom_top_blob, opt);

if (op_type == Operation_LOG10)
return unary_op_inplace<unary_op_log10>(bottom_top_blob, opt);

return 0;
#else // __riscv_vector
return UnaryOp::forward_inplace(bottom_top_blob, opt);
@@ -528,6 +539,14 @@ struct unary_op_tanh_fp16s
}
};

struct unary_op_log10_fp16s
{
vfloat16m8_t operator()(const vfloat16m8_t& x, const size_t& vl) const
{
return vfmul_vf_f16m8(log_ps(x, vl), 0.434294481903, vl);
}
};

} // namespace UnaryOp_riscv_functor

int UnaryOp_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const
@@ -585,6 +604,9 @@ int UnaryOp_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt
if (op_type == Operation_TANH)
return unary_op_inplace_fp16s<unary_op_tanh_fp16s>(bottom_top_blob, opt);

if (op_type == Operation_LOG10)
return unary_op_inplace_fp16s<unary_op_log10_fp16s>(bottom_top_blob, opt);

return 0;
}
#endif // __riscv_vector && __riscv_zfh


+ 11
- 0
src/layer/unaryop.cpp View File

@@ -183,6 +183,14 @@ struct unary_op_tanh
}
};

struct unary_op_log10
{
float operator()(const float& x) const
{
return (float)log10(x);
}
};

int UnaryOp::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
if (op_type == Operation_ABS)
@@ -236,6 +244,9 @@ int UnaryOp::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
if (op_type == Operation_TANH)
return unary_op_inplace<unary_op_tanh>(bottom_top_blob, opt);

if (op_type == Operation_LOG10)
return unary_op_inplace<unary_op_log10>(bottom_top_blob, opt);

return 0;
}



+ 2
- 1
src/layer/unaryop.h View File

@@ -46,7 +46,8 @@ public:
Operation_ACOS = 13,
Operation_ATAN = 14,
Operation_RECIPROCAL = 15,
Operation_TANH = 16
Operation_TANH = 16,
Operation_LOG10 = 17
};

public:


+ 1
- 0
src/layer/vulkan/shader/unaryop.comp View File

@@ -86,6 +86,7 @@ void main()
#else
if (op_type == 16) res = tanh(v);
#endif
if (op_type == 17) res = log(v) * afp(0.434294481903);

#if NCNN_image_shader
image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res);


+ 1
- 0
src/layer/vulkan/shader/unaryop_pack4.comp View File

@@ -86,6 +86,7 @@ void main()
#else
if (op_type == 16) res = tanh(v);
#endif
if (op_type == 17) res = log(v) * afp(0.434294481903);

#if NCNN_image_shader
image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res);


+ 5
- 0
src/layer/vulkan/shader/unaryop_pack8.comp View File

@@ -156,6 +156,11 @@ void main()
res[1] = tanh(v[1]);
#endif
}
if (op_type == 17)
{
res[0] = log(v[0]) * afp(0.434294481903);
res[1] = log(v[1]) * afp(0.434294481903);
}

#if NCNN_image_shader
image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res);


+ 29
- 0
src/layer/x86/unaryop_x86.cpp View File

@@ -686,6 +686,32 @@ struct unary_op_tanh
#endif // __SSE2__
};

struct unary_op_log10
{
float func(const float& x) const
{
return (float)log10(x);
}
#if __SSE2__
__m128 func_pack4(const __m128& x) const
{
return _mm_mul_ps(log_ps(x), _mm_set1_ps(0.434294481903));
}
#if __AVX__
__m256 func_pack8(const __m256& x) const
{
return _mm256_mul_ps(log256_ps(x), _mm256_set1_ps(0.434294481903));
}
#if __AVX512F__
__m512 func_pack16(const __m512& x) const
{
return _mm512_mul_ps(log512_ps(x), _mm512_set1_ps(0.434294481903));
}
#endif // __AVX512F__
#endif // __AVX__
#endif // __SSE2__
};

} // namespace UnaryOp_x86_functor

int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
@@ -742,6 +768,9 @@ int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
if (op_type == Operation_TANH)
return unary_op_inplace<unary_op_tanh>(bottom_top_blob, opt);

if (op_type == Operation_LOG10)
return unary_op_inplace<unary_op_log10>(bottom_top_blob, opt);

return 0;
}



+ 2
- 2
tests/test_unaryop.cpp View File

@@ -15,7 +15,7 @@
#include "layer/unaryop.h"
#include "testutil.h"

#define OP_TYPE_MAX 17
#define OP_TYPE_MAX 18

static int op_type = 0;

@@ -30,7 +30,7 @@ static int test_unaryop(const ncnn::Mat& _a)
a[i] *= 1000;
}
}
if (op_type == 5 || op_type == 6 || op_type == 8)
if (op_type == 5 || op_type == 6 || op_type == 8 || op_type == 17)
{
// value must be positive for sqrt rsqrt log
Randomize(a, 0.001f, 2.f);


+ 1
- 0
tools/pnnx/src/CMakeLists.txt View File

@@ -298,6 +298,7 @@ set(pnnx_pass_level4_SRCS
)

set(pnnx_pass_level5_SRCS
pass_level5/attribute_unpooling.cpp
pass_level5/eliminate_dropout.cpp
pass_level5/eliminate_identity_operator.cpp
pass_level5/eliminate_maxpool_indices.cpp


+ 2
- 0
tools/pnnx/src/ir.cpp View File

@@ -1074,6 +1074,7 @@ static std::string expand_expression(const Operator* op)
|| t == "exp"
|| t == "floor"
|| t == "log"
|| t == "log10"
|| t == "neg"
|| t == "reciprocal"
|| t == "rsqrt"
@@ -1101,6 +1102,7 @@ static std::string expand_expression(const Operator* op)
if (t == "exp") unaryop = "torch.exp";
if (t == "floor") unaryop = "torch.floor";
if (t == "log") unaryop = "torch.log";
if (t == "log10") unaryop = "torch.log10";
if (t == "neg") unaryop = "torch.neg";
if (t == "reciprocal") unaryop = "torch.reciprocal";
if (t == "rsqrt") unaryop = "torch.rsqrt";


+ 3
- 0
tools/pnnx/src/pass_level3/fuse_expression.cpp View File

@@ -78,6 +78,7 @@ static bool operand_maybe_tensor(const Operand* operand)
|| op->type == "aten::exp"
|| op->type == "aten::floor"
|| op->type == "aten::log"
|| op->type == "aten::log10"
|| op->type == "aten::neg"
|| op->type == "aten::reciprocal"
|| op->type == "aten::rsqrt"
@@ -286,6 +287,7 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s
|| op->type == "aten::exp"
|| op->type == "aten::floor"
|| op->type == "aten::log"
|| op->type == "aten::log10"
|| op->type == "aten::neg"
|| op->type == "aten::reciprocal"
|| op->type == "aten::rsqrt"
@@ -464,6 +466,7 @@ void fuse_expression(Graph& graph, const std::set<std::string>& foldable_constan
|| op->type == "aten::floor"
|| op->type == "aten::floor_divide"
|| op->type == "aten::log"
|| op->type == "aten::log10"
|| op->type == "aten::mul"
|| op->type == "aten::neg"
|| op->type == "aten::pow"


+ 3
- 0
tools/pnnx/src/pass_level5.cpp View File

@@ -14,6 +14,7 @@

#include "pass_level5.h"

#include "pass_level5/attribute_unpooling.h"
#include "pass_level5/fold_constants.h"
#include "pass_level5/eliminate_dropout.h"
#include "pass_level5/eliminate_identity_operator.h"
@@ -81,6 +82,8 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons

fuse_slice_copy(g);

attribute_unpooling(g);

fuse_static_batchnorm(g);
fuse_static_groupnorm(g);
fuse_static_instancenorm(g);


+ 80
- 0
tools/pnnx/src/pass_level5/attribute_unpooling.cpp View File

@@ -0,0 +1,80 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "attribute_unpooling.h"

#include <algorithm>

namespace pnnx {

void attribute_unpooling(Graph& graph)
{
while (1)
{
bool matched = false;

for (size_t i = 0; i < graph.ops.size(); i++)
{
Operator* op = graph.ops[i];

if (op->type != "pnnx.Attribute")
continue;

Operand* attr = op->outputs[0];

if (attr->consumers.size() < 2)
continue;

// multiple modules share same weight
matched = true;

for (int i = 1; i < (int)attr->consumers.size(); i++)
{
Operator* x = attr->consumers[i];

Operator* op2 = graph.new_operator_after("pnnx.Attribute", op->name + "_" + std::to_string(i), op);

op2->inputnames = op->inputnames;
op2->params = op->params;
op2->attrs = op->attrs;

Operand* attr2 = graph.new_operand(attr->name + "_" + std::to_string(i));

attr2->type = attr->type;
attr2->shape = attr->shape;
attr2->params = attr->params;

op2->outputs.push_back(attr2);

attr2->producer = op2;
attr2->consumers.push_back(x);

for (size_t j = 0; j < x->inputs.size(); j++)
{
if (x->inputs[j] == attr)
x->inputs[j] = attr2;
}
}

attr->consumers.resize(1);

break;
}

if (!matched)
break;
}
}

} // namespace pnnx

+ 21
- 0
tools/pnnx/src/pass_level5/attribute_unpooling.h View File

@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void attribute_unpooling(Graph& g);

} // namespace pnnx

+ 6
- 0
tools/pnnx/src/pass_level5/eval_expression.cpp View File

@@ -153,6 +153,7 @@ static std::string eval_expression(const Operator* op)
|| t == "exp"
|| t == "floor"
|| t == "log"
|| t == "log10"
|| t == "neg"
|| t == "reciprocal"
|| t == "rsqrt"
@@ -242,6 +243,11 @@ static std::string eval_expression(const Operator* op)
float r = log(af);
exprstack.push(std::to_string(r));
}
if (t == "log10")
{
float r = log10(af);
exprstack.push(std::to_string(r));
}
if (t == "neg")
{
float r = -af;


+ 1
- 1
tools/pnnx/src/pass_ncnn/Tensor_reshape.cpp View File

@@ -45,7 +45,7 @@ pnnx.Output output 1 0 out
{
const std::vector<int>& shape = captured_params.at("shape").ai;

const int batch_index = op->inputs[0]->params["__batch_index"].i;
const int batch_index = op->outputs[0]->params["__batch_index"].i;

if (batch_index != 0 && batch_index != 233)
{


+ 1
- 1
tools/pnnx/src/pass_ncnn/Tensor_view.cpp View File

@@ -45,7 +45,7 @@ pnnx.Output output 1 0 out
{
const std::vector<int>& shape = captured_params.at("shape").ai;

const int batch_index = op->inputs[0]->params["__batch_index"].i;
const int batch_index = op->outputs[0]->params["__batch_index"].i;

if (batch_index != 0 && batch_index != 233)
{


+ 10
- 0
tools/pnnx/src/pass_ncnn/expand_expression.cpp View File

@@ -128,6 +128,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx
|| t == "exp"
|| t == "floor"
|| t == "log"
|| t == "log10"
|| t == "neg"
|| t == "reciprocal"
|| t == "rsqrt"
@@ -154,6 +155,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx
if (t == "exp") op_unary->params["0"] = 7;
if (t == "floor") op_unary->params["0"] = 2;
if (t == "log") op_unary->params["0"] = 8;
if (t == "log10") op_unary->params["0"] = 17;
if (t == "neg") op_unary->params["0"] = 1;
if (t == "reciprocal") op_unary->params["0"] = 15;
if (t == "rsqrt") op_unary->params["0"] = 6;
@@ -222,6 +224,14 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx
op_binary->params["1"] = 1; // with_scalar
op_binary->params["2"] = std::stof(b);

if (t == "pow" && std::stof(b) == 2)
{
// replace pow 2 with square
op_binary->type = "UnaryOp";
op_binary->params.clear();
op_binary->params["0"] = 4;
}

Operand* op_binary_out = graph.new_operand(op->name + "_" + r);
op_binary_out->producer = op_binary;



+ 10
- 0
tools/pnnx/src/pass_ncnn/solve_batch_index.cpp View File

@@ -334,6 +334,16 @@ void solve_batch_index(Graph& graph)
}
}

// always treat 1-dim tensor as no batch axis
for (Operand* r : graph.operands)
{
if (r->shape.size() == 1)
{
fprintf(stderr, "force batch axis 233 for operand %s\n", r->name.c_str());
r->params["__batch_index"] = 233;
}
}

// fallback axis 233 for unknown
for (Operand* r : graph.operands)
{


+ 1
- 0
tools/pnnx/tests/CMakeLists.txt View File

@@ -256,6 +256,7 @@ pnnx_add_test(torch_exp)
pnnx_add_test(torch_floor)
pnnx_add_test(torch_imag)
pnnx_add_test(torch_log)
pnnx_add_test(torch_log10)
pnnx_add_test(torch_neg)
pnnx_add_test(torch_pow)
pnnx_add_test(torch_real)


+ 1
- 0
tools/pnnx/tests/ncnn/CMakeLists.txt View File

@@ -168,6 +168,7 @@ pnnx_ncnn_add_test(torch_cos)
pnnx_ncnn_add_test(torch_exp)
pnnx_ncnn_add_test(torch_floor)
pnnx_ncnn_add_test(torch_log)
pnnx_ncnn_add_test(torch_log10)
pnnx_ncnn_add_test(torch_neg)
pnnx_ncnn_add_test(torch_pow)
pnnx_ncnn_add_test(torch_reciprocal)


+ 61
- 0
tools/pnnx/tests/ncnn/test_torch_log10.py View File

@@ -0,0 +1,61 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
x = torch.log10(x)
y = torch.log10(y)
z = torch.log10(z)
return x, y, z

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(3, 16)
y = torch.rand(5, 9, 11)
z = torch.rand(8, 5, 9, 10)

a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_torch_log10.pt")

# torchscript to pnnx
import os
os.system("../../src/pnnx test_torch_log10.pt inputshape=[3,16],[5,9,11],[8,5,9,10]")

# ncnn inference
import test_torch_log10_ncnn
b = test_torch_log10_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.allclose(a0, b0, 1e-4, 1e-4):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

+ 61
- 0
tools/pnnx/tests/test_torch_log10.py View File

@@ -0,0 +1,61 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
x = torch.log10(x)
y = torch.log10(y)
z = torch.log10(z)
return x, y, z

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(1, 3, 16)
y = torch.rand(1, 5, 9, 11)
z = torch.rand(14, 8, 5, 9, 10)

a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_torch_log10.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_torch_log10.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]")

# pnnx inference
import test_torch_log10_pnnx
b = test_torch_log10_pnnx.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

Loading…
Cancel
Save