| @@ -1759,3 +1759,4 @@ Operation type: | |||
| - 14 = ATAN | |||
| - 15 = RECIPROCAL | |||
| - 16 = TANH | |||
| - 17 = LOG10 | |||
| @@ -360,6 +360,20 @@ struct unary_op_tanh | |||
| #endif // __ARM_NEON | |||
| }; | |||
| struct unary_op_log10 | |||
| { | |||
| float func(const float& x) const | |||
| { | |||
| return (float)log10(x); | |||
| } | |||
| #if __ARM_NEON | |||
| float32x4_t func_pack4(const float32x4_t& x) const | |||
| { | |||
| return vmulq_f32(log_ps(x), vdupq_n_f32(0.434294481903)); | |||
| } | |||
| #endif // __ARM_NEON | |||
| }; | |||
| } // namespace UnaryOp_arm_functor | |||
| int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const | |||
| @@ -429,6 +443,9 @@ int UnaryOp_arm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const | |||
| if (op_type == Operation_TANH) | |||
| return unary_op_inplace<unary_op_tanh>(bottom_top_blob, opt); | |||
| if (op_type == Operation_LOG10) | |||
| return unary_op_inplace<unary_op_log10>(bottom_top_blob, opt); | |||
| return 0; | |||
| } | |||
| @@ -556,6 +573,9 @@ int UnaryOp_arm::forward_inplace_bf16s(Mat& bottom_top_blob, const Option& opt) | |||
| if (op_type == Operation_TANH) | |||
| return unary_op_inplace_bf16s<unary_op_tanh>(bottom_top_blob, opt); | |||
| if (op_type == Operation_LOG10) | |||
| return unary_op_inplace_bf16s<unary_op_log10>(bottom_top_blob, opt); | |||
| return 0; | |||
| } | |||
| #endif // NCNN_BF16 | |||
| @@ -436,6 +436,22 @@ struct unary_op_tanh_fp16s | |||
| } | |||
| }; | |||
| struct unary_op_log10_fp16s | |||
| { | |||
| __fp16 func(const __fp16& x) const | |||
| { | |||
| return (__fp16)log10(x); | |||
| } | |||
| float16x4_t func_pack4(const float16x4_t& x) const | |||
| { | |||
| return vmul_f16(log_ps(x), vdup_n_f16(0.434294481903)); | |||
| } | |||
| float16x8_t func_pack8(const float16x8_t& x) const | |||
| { | |||
| return vmulq_f16(log_ps(x), vdupq_n_f16(0.434294481903)); | |||
| } | |||
| }; | |||
| } // namespace UnaryOp_arm_functor | |||
| int UnaryOp_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const | |||
| @@ -493,6 +509,9 @@ int UnaryOp_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) | |||
| if (op_type == Operation_TANH) | |||
| return unary_op_inplace_fp16s<unary_op_tanh_fp16s>(bottom_top_blob, opt); | |||
| if (op_type == Operation_LOG10) | |||
| return unary_op_inplace_fp16s<unary_op_log10_fp16s>(bottom_top_blob, opt); | |||
| return 0; | |||
| } | |||
| #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
| @@ -364,6 +364,20 @@ struct unary_op_tanh | |||
| #endif // __loongarch_sx | |||
| }; | |||
| struct unary_op_log10 | |||
| { | |||
| float func(const float& x) const | |||
| { | |||
| return (float)log10(x); | |||
| } | |||
| #if __loongarch_sx | |||
| __m128 func_pack4(const __m128& x) const | |||
| { | |||
| return __lsx_vfmul_s(log_ps(x), __lsx_vreplfr2vr_s(0.434294481903)); | |||
| } | |||
| #endif // __loongarch_sx | |||
| }; | |||
| } // namespace UnaryOp_loongarch_functor | |||
| int UnaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt) const | |||
| @@ -421,6 +435,9 @@ int UnaryOp_loongarch::forward_inplace(Mat& bottom_top_blob, const Option& opt) | |||
| if (op_type == Operation_TANH) | |||
| return unary_op_inplace<unary_op_tanh>(bottom_top_blob, opt); | |||
| if (op_type == Operation_LOG10) | |||
| return unary_op_inplace<unary_op_log10>(bottom_top_blob, opt); | |||
| return 0; | |||
| } | |||
| @@ -374,6 +374,20 @@ struct unary_op_tanh | |||
| #endif // __mips_msa | |||
| }; | |||
| struct unary_op_log10 | |||
| { | |||
| float func(const float& x) const | |||
| { | |||
| return (float)log10(x); | |||
| } | |||
| #if __mips_msa | |||
| v4f32 func_pack4(const v4f32& x) const | |||
| { | |||
| return __msa_fmul_w(log_ps(x), __msa_fill_w_f32(0.434294481903)); | |||
| } | |||
| #endif // __mips_msa | |||
| }; | |||
| } // namespace UnaryOp_mips_functor | |||
| int UnaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const | |||
| @@ -431,6 +445,9 @@ int UnaryOp_mips::forward_inplace(Mat& bottom_top_blob, const Option& opt) const | |||
| if (op_type == Operation_TANH) | |||
| return unary_op_inplace<unary_op_tanh>(bottom_top_blob, opt); | |||
| if (op_type == Operation_LOG10) | |||
| return unary_op_inplace<unary_op_log10>(bottom_top_blob, opt); | |||
| return 0; | |||
| } | |||
| @@ -245,6 +245,14 @@ struct unary_op_tanh | |||
| } | |||
| }; | |||
| struct unary_op_log10 | |||
| { | |||
| vfloat32m8_t operator()(const vfloat32m8_t& x, const size_t& vl) const | |||
| { | |||
| return vfmul_vf_f32m8(log_ps(x, vl), 0.434294481903, vl); | |||
| } | |||
| }; | |||
| } // namespace UnaryOp_riscv_functor | |||
| #endif // __riscv_vector | |||
| @@ -311,6 +319,9 @@ int UnaryOp_riscv::forward_inplace(Mat& bottom_top_blob, const Option& opt) cons | |||
| if (op_type == Operation_TANH) | |||
| return unary_op_inplace<unary_op_tanh>(bottom_top_blob, opt); | |||
| if (op_type == Operation_LOG10) | |||
| return unary_op_inplace<unary_op_log10>(bottom_top_blob, opt); | |||
| return 0; | |||
| #else // __riscv_vector | |||
| return UnaryOp::forward_inplace(bottom_top_blob, opt); | |||
| @@ -528,6 +539,14 @@ struct unary_op_tanh_fp16s | |||
| } | |||
| }; | |||
| struct unary_op_log10_fp16s | |||
| { | |||
| vfloat16m8_t operator()(const vfloat16m8_t& x, const size_t& vl) const | |||
| { | |||
| return vfmul_vf_f16m8(log_ps(x, vl), 0.434294481903, vl); | |||
| } | |||
| }; | |||
| } // namespace UnaryOp_riscv_functor | |||
| int UnaryOp_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) const | |||
| @@ -585,6 +604,9 @@ int UnaryOp_riscv::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt | |||
| if (op_type == Operation_TANH) | |||
| return unary_op_inplace_fp16s<unary_op_tanh_fp16s>(bottom_top_blob, opt); | |||
| if (op_type == Operation_LOG10) | |||
| return unary_op_inplace_fp16s<unary_op_log10_fp16s>(bottom_top_blob, opt); | |||
| return 0; | |||
| } | |||
| #endif // __riscv_vector && __riscv_zfh | |||
| @@ -183,6 +183,14 @@ struct unary_op_tanh | |||
| } | |||
| }; | |||
| struct unary_op_log10 | |||
| { | |||
| float operator()(const float& x) const | |||
| { | |||
| return (float)log10(x); | |||
| } | |||
| }; | |||
| int UnaryOp::forward_inplace(Mat& bottom_top_blob, const Option& opt) const | |||
| { | |||
| if (op_type == Operation_ABS) | |||
| @@ -236,6 +244,9 @@ int UnaryOp::forward_inplace(Mat& bottom_top_blob, const Option& opt) const | |||
| if (op_type == Operation_TANH) | |||
| return unary_op_inplace<unary_op_tanh>(bottom_top_blob, opt); | |||
| if (op_type == Operation_LOG10) | |||
| return unary_op_inplace<unary_op_log10>(bottom_top_blob, opt); | |||
| return 0; | |||
| } | |||
| @@ -46,7 +46,8 @@ public: | |||
| Operation_ACOS = 13, | |||
| Operation_ATAN = 14, | |||
| Operation_RECIPROCAL = 15, | |||
| Operation_TANH = 16 | |||
| Operation_TANH = 16, | |||
| Operation_LOG10 = 17 | |||
| }; | |||
| public: | |||
| @@ -86,6 +86,7 @@ void main() | |||
| #else | |||
| if (op_type == 16) res = tanh(v); | |||
| #endif | |||
| if (op_type == 17) res = log(v) * afp(0.434294481903); | |||
| #if NCNN_image_shader | |||
| image3d_st1(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| @@ -86,6 +86,7 @@ void main() | |||
| #else | |||
| if (op_type == 16) res = tanh(v); | |||
| #endif | |||
| if (op_type == 17) res = log(v) * afp(0.434294481903); | |||
| #if NCNN_image_shader | |||
| image3d_st4(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| @@ -156,6 +156,11 @@ void main() | |||
| res[1] = tanh(v[1]); | |||
| #endif | |||
| } | |||
| if (op_type == 17) | |||
| { | |||
| res[0] = log(v[0]) * afp(0.434294481903); | |||
| res[1] = log(v[1]) * afp(0.434294481903); | |||
| } | |||
| #if NCNN_image_shader | |||
| image3d_st8(top_blob_3d, ivec3(gx, gy, gz), res); | |||
| @@ -686,6 +686,32 @@ struct unary_op_tanh | |||
| #endif // __SSE2__ | |||
| }; | |||
| struct unary_op_log10 | |||
| { | |||
| float func(const float& x) const | |||
| { | |||
| return (float)log10(x); | |||
| } | |||
| #if __SSE2__ | |||
| __m128 func_pack4(const __m128& x) const | |||
| { | |||
| return _mm_mul_ps(log_ps(x), _mm_set1_ps(0.434294481903)); | |||
| } | |||
| #if __AVX__ | |||
| __m256 func_pack8(const __m256& x) const | |||
| { | |||
| return _mm256_mul_ps(log256_ps(x), _mm256_set1_ps(0.434294481903)); | |||
| } | |||
| #if __AVX512F__ | |||
| __m512 func_pack16(const __m512& x) const | |||
| { | |||
| return _mm512_mul_ps(log512_ps(x), _mm512_set1_ps(0.434294481903)); | |||
| } | |||
| #endif // __AVX512F__ | |||
| #endif // __AVX__ | |||
| #endif // __SSE2__ | |||
| }; | |||
| } // namespace UnaryOp_x86_functor | |||
| int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const | |||
| @@ -742,6 +768,9 @@ int UnaryOp_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const | |||
| if (op_type == Operation_TANH) | |||
| return unary_op_inplace<unary_op_tanh>(bottom_top_blob, opt); | |||
| if (op_type == Operation_LOG10) | |||
| return unary_op_inplace<unary_op_log10>(bottom_top_blob, opt); | |||
| return 0; | |||
| } | |||
| @@ -15,7 +15,7 @@ | |||
| #include "layer/unaryop.h" | |||
| #include "testutil.h" | |||
| #define OP_TYPE_MAX 17 | |||
| #define OP_TYPE_MAX 18 | |||
| static int op_type = 0; | |||
| @@ -30,7 +30,7 @@ static int test_unaryop(const ncnn::Mat& _a) | |||
| a[i] *= 1000; | |||
| } | |||
| } | |||
| if (op_type == 5 || op_type == 6 || op_type == 8) | |||
| if (op_type == 5 || op_type == 6 || op_type == 8 || op_type == 17) | |||
| { | |||
| // value must be positive for sqrt rsqrt log | |||
| Randomize(a, 0.001f, 2.f); | |||
| @@ -298,6 +298,7 @@ set(pnnx_pass_level4_SRCS | |||
| ) | |||
| set(pnnx_pass_level5_SRCS | |||
| pass_level5/attribute_unpooling.cpp | |||
| pass_level5/eliminate_dropout.cpp | |||
| pass_level5/eliminate_identity_operator.cpp | |||
| pass_level5/eliminate_maxpool_indices.cpp | |||
| @@ -1074,6 +1074,7 @@ static std::string expand_expression(const Operator* op) | |||
| || t == "exp" | |||
| || t == "floor" | |||
| || t == "log" | |||
| || t == "log10" | |||
| || t == "neg" | |||
| || t == "reciprocal" | |||
| || t == "rsqrt" | |||
| @@ -1101,6 +1102,7 @@ static std::string expand_expression(const Operator* op) | |||
| if (t == "exp") unaryop = "torch.exp"; | |||
| if (t == "floor") unaryop = "torch.floor"; | |||
| if (t == "log") unaryop = "torch.log"; | |||
| if (t == "log10") unaryop = "torch.log10"; | |||
| if (t == "neg") unaryop = "torch.neg"; | |||
| if (t == "reciprocal") unaryop = "torch.reciprocal"; | |||
| if (t == "rsqrt") unaryop = "torch.rsqrt"; | |||
| @@ -78,6 +78,7 @@ static bool operand_maybe_tensor(const Operand* operand) | |||
| || op->type == "aten::exp" | |||
| || op->type == "aten::floor" | |||
| || op->type == "aten::log" | |||
| || op->type == "aten::log10" | |||
| || op->type == "aten::neg" | |||
| || op->type == "aten::reciprocal" | |||
| || op->type == "aten::rsqrt" | |||
| @@ -286,6 +287,7 @@ static void fuse_expression(Graph& graph, Operand* operand, std::string& expr, s | |||
| || op->type == "aten::exp" | |||
| || op->type == "aten::floor" | |||
| || op->type == "aten::log" | |||
| || op->type == "aten::log10" | |||
| || op->type == "aten::neg" | |||
| || op->type == "aten::reciprocal" | |||
| || op->type == "aten::rsqrt" | |||
| @@ -464,6 +466,7 @@ void fuse_expression(Graph& graph, const std::set<std::string>& foldable_constan | |||
| || op->type == "aten::floor" | |||
| || op->type == "aten::floor_divide" | |||
| || op->type == "aten::log" | |||
| || op->type == "aten::log10" | |||
| || op->type == "aten::mul" | |||
| || op->type == "aten::neg" | |||
| || op->type == "aten::pow" | |||
| @@ -14,6 +14,7 @@ | |||
| #include "pass_level5.h" | |||
| #include "pass_level5/attribute_unpooling.h" | |||
| #include "pass_level5/fold_constants.h" | |||
| #include "pass_level5/eliminate_dropout.h" | |||
| #include "pass_level5/eliminate_identity_operator.h" | |||
| @@ -81,6 +82,8 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons | |||
| fuse_slice_copy(g); | |||
| attribute_unpooling(g); | |||
| fuse_static_batchnorm(g); | |||
| fuse_static_groupnorm(g); | |||
| fuse_static_instancenorm(g); | |||
| @@ -0,0 +1,80 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "attribute_unpooling.h" | |||
| #include <algorithm> | |||
| namespace pnnx { | |||
| void attribute_unpooling(Graph& graph) | |||
| { | |||
| while (1) | |||
| { | |||
| bool matched = false; | |||
| for (size_t i = 0; i < graph.ops.size(); i++) | |||
| { | |||
| Operator* op = graph.ops[i]; | |||
| if (op->type != "pnnx.Attribute") | |||
| continue; | |||
| Operand* attr = op->outputs[0]; | |||
| if (attr->consumers.size() < 2) | |||
| continue; | |||
| // multiple modules share same weight | |||
| matched = true; | |||
| for (int i = 1; i < (int)attr->consumers.size(); i++) | |||
| { | |||
| Operator* x = attr->consumers[i]; | |||
| Operator* op2 = graph.new_operator_after("pnnx.Attribute", op->name + "_" + std::to_string(i), op); | |||
| op2->inputnames = op->inputnames; | |||
| op2->params = op->params; | |||
| op2->attrs = op->attrs; | |||
| Operand* attr2 = graph.new_operand(attr->name + "_" + std::to_string(i)); | |||
| attr2->type = attr->type; | |||
| attr2->shape = attr->shape; | |||
| attr2->params = attr->params; | |||
| op2->outputs.push_back(attr2); | |||
| attr2->producer = op2; | |||
| attr2->consumers.push_back(x); | |||
| for (size_t j = 0; j < x->inputs.size(); j++) | |||
| { | |||
| if (x->inputs[j] == attr) | |||
| x->inputs[j] = attr2; | |||
| } | |||
| } | |||
| attr->consumers.resize(1); | |||
| break; | |||
| } | |||
| if (!matched) | |||
| break; | |||
| } | |||
| } | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,21 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "ir.h" | |||
| namespace pnnx { | |||
| void attribute_unpooling(Graph& g); | |||
| } // namespace pnnx | |||
| @@ -153,6 +153,7 @@ static std::string eval_expression(const Operator* op) | |||
| || t == "exp" | |||
| || t == "floor" | |||
| || t == "log" | |||
| || t == "log10" | |||
| || t == "neg" | |||
| || t == "reciprocal" | |||
| || t == "rsqrt" | |||
| @@ -242,6 +243,11 @@ static std::string eval_expression(const Operator* op) | |||
| float r = log(af); | |||
| exprstack.push(std::to_string(r)); | |||
| } | |||
| if (t == "log10") | |||
| { | |||
| float r = log10(af); | |||
| exprstack.push(std::to_string(r)); | |||
| } | |||
| if (t == "neg") | |||
| { | |||
| float r = -af; | |||
| @@ -45,7 +45,7 @@ pnnx.Output output 1 0 out | |||
| { | |||
| const std::vector<int>& shape = captured_params.at("shape").ai; | |||
| const int batch_index = op->inputs[0]->params["__batch_index"].i; | |||
| const int batch_index = op->outputs[0]->params["__batch_index"].i; | |||
| if (batch_index != 0 && batch_index != 233) | |||
| { | |||
| @@ -45,7 +45,7 @@ pnnx.Output output 1 0 out | |||
| { | |||
| const std::vector<int>& shape = captured_params.at("shape").ai; | |||
| const int batch_index = op->inputs[0]->params["__batch_index"].i; | |||
| const int batch_index = op->outputs[0]->params["__batch_index"].i; | |||
| if (batch_index != 0 && batch_index != 233) | |||
| { | |||
| @@ -128,6 +128,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx | |||
| || t == "exp" | |||
| || t == "floor" | |||
| || t == "log" | |||
| || t == "log10" | |||
| || t == "neg" | |||
| || t == "reciprocal" | |||
| || t == "rsqrt" | |||
| @@ -154,6 +155,7 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx | |||
| if (t == "exp") op_unary->params["0"] = 7; | |||
| if (t == "floor") op_unary->params["0"] = 2; | |||
| if (t == "log") op_unary->params["0"] = 8; | |||
| if (t == "log10") op_unary->params["0"] = 17; | |||
| if (t == "neg") op_unary->params["0"] = 1; | |||
| if (t == "reciprocal") op_unary->params["0"] = 15; | |||
| if (t == "rsqrt") op_unary->params["0"] = 6; | |||
| @@ -222,6 +224,14 @@ static std::string expand_expression(Graph& graph, const Operator* op, int& pnnx | |||
| op_binary->params["1"] = 1; // with_scalar | |||
| op_binary->params["2"] = std::stof(b); | |||
| if (t == "pow" && std::stof(b) == 2) | |||
| { | |||
| // replace pow 2 with square | |||
| op_binary->type = "UnaryOp"; | |||
| op_binary->params.clear(); | |||
| op_binary->params["0"] = 4; | |||
| } | |||
| Operand* op_binary_out = graph.new_operand(op->name + "_" + r); | |||
| op_binary_out->producer = op_binary; | |||
| @@ -334,6 +334,16 @@ void solve_batch_index(Graph& graph) | |||
| } | |||
| } | |||
| // always treat 1-dim tensor as no batch axis | |||
| for (Operand* r : graph.operands) | |||
| { | |||
| if (r->shape.size() == 1) | |||
| { | |||
| fprintf(stderr, "force batch axis 233 for operand %s\n", r->name.c_str()); | |||
| r->params["__batch_index"] = 233; | |||
| } | |||
| } | |||
| // fallback axis 233 for unknown | |||
| for (Operand* r : graph.operands) | |||
| { | |||
| @@ -256,6 +256,7 @@ pnnx_add_test(torch_exp) | |||
| pnnx_add_test(torch_floor) | |||
| pnnx_add_test(torch_imag) | |||
| pnnx_add_test(torch_log) | |||
| pnnx_add_test(torch_log10) | |||
| pnnx_add_test(torch_neg) | |||
| pnnx_add_test(torch_pow) | |||
| pnnx_add_test(torch_real) | |||
| @@ -168,6 +168,7 @@ pnnx_ncnn_add_test(torch_cos) | |||
| pnnx_ncnn_add_test(torch_exp) | |||
| pnnx_ncnn_add_test(torch_floor) | |||
| pnnx_ncnn_add_test(torch_log) | |||
| pnnx_ncnn_add_test(torch_log10) | |||
| pnnx_ncnn_add_test(torch_neg) | |||
| pnnx_ncnn_add_test(torch_pow) | |||
| pnnx_ncnn_add_test(torch_reciprocal) | |||
| @@ -0,0 +1,61 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z): | |||
| x = torch.log10(x) | |||
| y = torch.log10(y) | |||
| z = torch.log10(z) | |||
| return x, y, z | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(3, 16) | |||
| y = torch.rand(5, 9, 11) | |||
| z = torch.rand(8, 5, 9, 10) | |||
| a = net(x, y, z) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod.save("test_torch_log10.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../../src/pnnx test_torch_log10.pt inputshape=[3,16],[5,9,11],[8,5,9,10]") | |||
| # ncnn inference | |||
| import test_torch_log10_ncnn | |||
| b = test_torch_log10_ncnn.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.allclose(a0, b0, 1e-4, 1e-4): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||
| @@ -0,0 +1,61 @@ | |||
| # Tencent is pleased to support the open source community by making ncnn available. | |||
| # | |||
| # Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | |||
| # | |||
| # Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| # in compliance with the License. You may obtain a copy of the License at | |||
| # | |||
| # https://opensource.org/licenses/BSD-3-Clause | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software distributed | |||
| # under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| # CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| # specific language governing permissions and limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Model(nn.Module): | |||
| def __init__(self): | |||
| super(Model, self).__init__() | |||
| def forward(self, x, y, z): | |||
| x = torch.log10(x) | |||
| y = torch.log10(y) | |||
| z = torch.log10(z) | |||
| return x, y, z | |||
| def test(): | |||
| net = Model() | |||
| net.eval() | |||
| torch.manual_seed(0) | |||
| x = torch.rand(1, 3, 16) | |||
| y = torch.rand(1, 5, 9, 11) | |||
| z = torch.rand(14, 8, 5, 9, 10) | |||
| a = net(x, y, z) | |||
| # export torchscript | |||
| mod = torch.jit.trace(net, (x, y, z)) | |||
| mod.save("test_torch_log10.pt") | |||
| # torchscript to pnnx | |||
| import os | |||
| os.system("../src/pnnx test_torch_log10.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]") | |||
| # pnnx inference | |||
| import test_torch_log10_pnnx | |||
| b = test_torch_log10_pnnx.test_inference() | |||
| for a0, b0 in zip(a, b): | |||
| if not torch.equal(a0, b0): | |||
| return False | |||
| return True | |||
| if __name__ == "__main__": | |||
| if test(): | |||
| exit(0) | |||
| else: | |||
| exit(1) | |||