| @@ -11,6 +11,8 @@ set(pnnx_pass_level0_SRCS | |||
| ) | |||
| set(pnnx_pass_level1_SRCS | |||
| pass_level1/fuse_module_pass.cpp | |||
| pass_level1/nn_AdaptiveAvgPool1d.cpp | |||
| pass_level1/nn_AdaptiveAvgPool2d.cpp | |||
| pass_level1/nn_AdaptiveAvgPool3d.cpp | |||
| @@ -688,20 +690,20 @@ if(onnxruntime_FOUND) | |||
| pass_onnx/shape_inference.cpp | |||
| pass_onnx/fuse_constant_as_attribute.cpp | |||
| pass_onnx/nn_AdaptiveAvgPool2d.cpp | |||
| pass_onnx/nn_AdaptiveAvgPool3d.cpp | |||
| pass_onnx/nn_AvgPool2d.cpp | |||
| pass_onnx/nn_AvgPool3d.cpp | |||
| pass_onnx/nn_BatchNorm2d.cpp | |||
| pass_onnx/nn_BatchNorm3d.cpp | |||
| pass_onnx/nn_Conv2d.cpp | |||
| pass_onnx/nn_Conv3d.cpp | |||
| pass_onnx/nn_GELU.cpp | |||
| pass_onnx/nn_LayerNorm.cpp | |||
| pass_onnx/nn_Linear.cpp | |||
| pass_onnx/nn_MaxPool2d.cpp | |||
| pass_onnx/nn_MaxPool3d.cpp | |||
| pass_onnx/nn_MultiheadAttention.cpp | |||
| # pass_onnx/nn_AdaptiveAvgPool2d.cpp | |||
| # pass_onnx/nn_AdaptiveAvgPool3d.cpp | |||
| # pass_onnx/nn_AvgPool2d.cpp | |||
| # pass_onnx/nn_AvgPool3d.cpp | |||
| # pass_onnx/nn_BatchNorm2d.cpp | |||
| # pass_onnx/nn_BatchNorm3d.cpp | |||
| # pass_onnx/nn_Conv2d.cpp | |||
| # pass_onnx/nn_Conv3d.cpp | |||
| # pass_onnx/nn_GELU.cpp | |||
| # pass_onnx/nn_LayerNorm.cpp | |||
| # pass_onnx/nn_Linear.cpp | |||
| # pass_onnx/nn_MaxPool2d.cpp | |||
| # pass_onnx/nn_MaxPool3d.cpp | |||
| # pass_onnx/nn_MultiheadAttention.cpp | |||
| ) | |||
| set(onnx2pnnx_SRCS | |||
| @@ -34,6 +34,9 @@ struct Node; | |||
| namespace at { | |||
| class Tensor; | |||
| } | |||
| namespace pnnx { | |||
| class TorchTensorProxy; | |||
| } // namespace pnnx | |||
| #endif // BUILD_TORCH2PNNX | |||
| #if BUILD_ONNX2PNNX | |||
| @@ -230,6 +233,7 @@ public: | |||
| #if BUILD_TORCH2PNNX | |||
| Attribute(const at::Tensor& t); | |||
| Attribute(const TorchTensorProxy& t); | |||
| #endif | |||
| #if BUILD_ONNX2PNNX | |||
| Attribute(const onnx::TensorProto& t); | |||
| @@ -31,6 +31,7 @@ int64_t cuda_version(); | |||
| #include "pass_level0.h" | |||
| #include "pass_level1.h" | |||
| #include "pass_level1/fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -372,6 +373,11 @@ Attribute::Attribute(const at::Tensor& t) | |||
| } | |||
| } | |||
| Attribute::Attribute(const TorchTensorProxy& t) | |||
| : Attribute(t.t()) | |||
| { | |||
| } | |||
| Operand* Graph::new_operand(const torch::jit::Value* v) | |||
| { | |||
| // Operand* r = new Operand; | |||
| @@ -442,17 +448,6 @@ static const char* get_at_tensor_type_str(const at::ScalarType& st) | |||
| return ""; | |||
| } | |||
| const torch::jit::Node* find_node_by_kind(const std::shared_ptr<torch::jit::Graph>& graph, const std::string& kind) | |||
| { | |||
| for (const auto& n : graph->nodes()) | |||
| { | |||
| if (n->kind().toDisplayString() == kind) | |||
| return n; | |||
| } | |||
| return 0; | |||
| } | |||
| static void print_shape_list(const std::vector<std::vector<int64_t> >& shapes, const std::vector<std::string>& types) | |||
| { | |||
| for (size_t i = 0; i < shapes.size(); i++) | |||
| @@ -508,7 +503,7 @@ static void get_traced_input_shape(const std::string& ptpath, std::vector<std::v | |||
| { | |||
| // read traced_inputs.pkl | |||
| caffe2::serialize::PyTorchStreamReader reader(ptpath); | |||
| auto v = torch::jit::readArchiveAndTensors("traced_inputs", "", "traced_inputs/", std::nullopt, std::nullopt, std::nullopt, reader); | |||
| auto v = torch::jit::readArchiveAndTensors("traced_inputs", "", "traced_inputs/", c10::nullopt, c10::nullopt, c10::nullopt, reader); | |||
| if (!v.isGenericDict()) | |||
| return; | |||
| @@ -13,7 +13,7 @@ | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "inline_block.h" | |||
| #include "../pass_level1.h" | |||
| #include "../pass_level1/fuse_module_pass.h" | |||
| #include <set> | |||
| @@ -12,43 +12,16 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include <torch/csrc/jit/passes/quantization/helper.h> | |||
| #include <torch/script.h> | |||
| #include <torch/csrc/jit/api/module.h> | |||
| #include <torch/csrc/api/include/torch/version.h> | |||
| #include <torch/csrc/jit/passes/quantization/helper.h> | |||
| #include "pass_level1.h" | |||
| namespace pnnx { | |||
| FuseModulePass::~FuseModulePass() | |||
| { | |||
| } | |||
| void FuseModulePass::write(Operator* /*op*/, const std::shared_ptr<torch::jit::Graph>& /*graph*/) const | |||
| { | |||
| } | |||
| void FuseModulePass::write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& /*mod*/) const | |||
| { | |||
| write(op, graph); | |||
| } | |||
| static std::vector<const FuseModulePass*> g_global_pnnx_fuse_module_passes; | |||
| const std::vector<const FuseModulePass*>& get_global_pnnx_fuse_module_passes() | |||
| { | |||
| return g_global_pnnx_fuse_module_passes; | |||
| } | |||
| #include "pass_level1/fuse_module_pass.h" | |||
| FuseModulePassRegister::FuseModulePassRegister(const FuseModulePass* _pass) | |||
| : pass(_pass) | |||
| { | |||
| g_global_pnnx_fuse_module_passes.push_back(pass); | |||
| } | |||
| FuseModulePassRegister::~FuseModulePassRegister() | |||
| { | |||
| delete pass; | |||
| } | |||
| namespace pnnx { | |||
| static void fuse_moduleop_unpack(Graph& graph, const std::vector<std::string>& module_operators) | |||
| { | |||
| @@ -399,10 +372,12 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch::jit | |||
| op->name = wrapped_name; | |||
| #if TORCH_VERSION_MAJOR >= 2 || (TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 11) | |||
| ow->write(op, toGraphFunction(function).graph(), sub_mod); | |||
| TorchGraphProxy graph_proxy(toGraphFunction(function).graph()); | |||
| #else | |||
| ow->write(op, function.graph(), sub_mod); | |||
| TorchGraphProxy graph_proxy(function.graph()); | |||
| #endif | |||
| TorchModuleProxy sub_mod_proxy(sub_mod); | |||
| ow->write(op, graph_proxy, sub_mod_proxy); | |||
| break; | |||
| } | |||
| @@ -15,39 +15,10 @@ | |||
| #ifndef PNNX_PASS_LEVEL1_H | |||
| #define PNNX_PASS_LEVEL1_H | |||
| #include <torch/script.h> | |||
| #include <torch/csrc/jit/api/module.h> | |||
| #include "ir.h" | |||
| namespace pnnx { | |||
| class FuseModulePass | |||
| { | |||
| public: | |||
| virtual ~FuseModulePass(); | |||
| virtual const char* match_type_str() const = 0; | |||
| virtual const char* type_str() const = 0; | |||
| virtual void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const; | |||
| virtual void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const; | |||
| }; | |||
| class FuseModulePassRegister | |||
| { | |||
| public: | |||
| FuseModulePassRegister(const FuseModulePass* pass); | |||
| ~FuseModulePassRegister(); | |||
| const FuseModulePass* pass; | |||
| }; | |||
| const std::vector<const FuseModulePass*>& get_global_pnnx_fuse_module_passes(); | |||
| #define REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(CLASS) \ | |||
| static FuseModulePassRegister g_global_pnnx_fusemodulepass_##CLASS##_register(new CLASS); | |||
| void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch::jit::Graph>& g, const std::vector<std::string>& module_operators, Graph& pg); | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,223 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "fuse_module_pass.h" | |||
| #include <torch/script.h> | |||
| #include <torch/csrc/jit/api/module.h> | |||
| #include <torch/csrc/jit/passes/quantization/helper.h> | |||
| namespace pnnx { | |||
| std::string TorchNodeProxy::kind() const | |||
| { | |||
| return node->kind().toDisplayString(); | |||
| } | |||
| bool TorchNodeProxy::hasNamedInput(const std::string& name) const | |||
| { | |||
| return node->hasNamedInput(name); | |||
| } | |||
| const torch::jit::Value* TorchNodeProxy::namedInput(const std::string& name) const | |||
| { | |||
| return node->namedInput(name); | |||
| } | |||
| int TorchNodeProxy::input_count() const | |||
| { | |||
| return node->inputs().size(); | |||
| } | |||
| const torch::jit::Value* TorchNodeProxy::input(int i) const | |||
| { | |||
| return node->input(i); | |||
| } | |||
| int TorchNodeProxy::output_count() const | |||
| { | |||
| return node->outputs().size(); | |||
| } | |||
| const torch::jit::Value* TorchNodeProxy::output(int i) const | |||
| { | |||
| return node->output(i); | |||
| } | |||
| bool TorchNodeProxy::is_input_none(int i) const | |||
| { | |||
| return node->input(i)->type()->kind() == c10::TypeKind::NoneType; | |||
| } | |||
| TorchGraphProxy::TorchGraphProxy(const std::shared_ptr<torch::jit::Graph> _graph) | |||
| : graph(_graph) | |||
| { | |||
| for (const auto& n : graph->nodes()) | |||
| { | |||
| nodes.push_back(n); | |||
| } | |||
| } | |||
| const TorchNodeProxy* TorchGraphProxy::find_node_by_kind(const std::string& kind) const | |||
| { | |||
| for (const auto& n : nodes) | |||
| { | |||
| if (n.node->kind().toDisplayString() == kind) | |||
| return &n; | |||
| } | |||
| return 0; | |||
| } | |||
| const TorchNodeProxy* TorchGraphProxy::find_producer_node_by_value(const torch::jit::Value* value) const | |||
| { | |||
| for (const auto& n : nodes) | |||
| { | |||
| if (n.node == value->node()) | |||
| return &n; | |||
| } | |||
| fprintf(stderr, "TorchGraphProxy find_producer_node_by_value failed\n"); | |||
| return 0; | |||
| } | |||
| int TorchGraphProxy::input_count() const | |||
| { | |||
| return std::as_const(*graph).inputs().size(); | |||
| } | |||
| const torch::jit::Value* TorchGraphProxy::input(int i) const | |||
| { | |||
| return std::as_const(*graph).inputs()[i]; | |||
| } | |||
| int TorchGraphProxy::output_count() const | |||
| { | |||
| return std::as_const(*graph).outputs().size(); | |||
| } | |||
| const torch::jit::Value* TorchGraphProxy::output(int i) const | |||
| { | |||
| return std::as_const(*graph).outputs()[i]; | |||
| } | |||
| void TorchGraphProxy::dump() const | |||
| { | |||
| graph->dump(); | |||
| } | |||
| class TorchTensorProxyPrivate | |||
| { | |||
| public: | |||
| at::Tensor t; | |||
| }; | |||
| TorchTensorProxy::TorchTensorProxy(const at::Tensor& _t) | |||
| : d(new TorchTensorProxyPrivate) | |||
| { | |||
| d->t = _t; | |||
| } | |||
| TorchTensorProxy::~TorchTensorProxy() | |||
| { | |||
| delete d; | |||
| } | |||
| const at::Tensor& TorchTensorProxy::t() const | |||
| { | |||
| return d->t; | |||
| } | |||
| int TorchTensorProxy::size(size_t i) const | |||
| { | |||
| return d->t.size(i); | |||
| } | |||
| TorchModuleProxy::TorchModuleProxy(const torch::jit::Module& _mod) | |||
| : mod(_mod) | |||
| { | |||
| const std::vector<c10::ClassAttribute>& attributes = mod._ivalue()->type()->getAttributes(); | |||
| for (size_t i = 0; i < attributes.size(); i++) | |||
| { | |||
| const std::string& name = attributes[i].getName(); | |||
| const c10::IValue& ivalue = mod._ivalue()->getSlot(i); | |||
| if (name.empty()) | |||
| continue; | |||
| if (ivalue.isTensor()) | |||
| attrs.emplace(name, ivalue.toTensor()); | |||
| if (ivalue.isModule()) | |||
| { | |||
| const torch::jit::Module submod = ivalue.toModule(); | |||
| const std::vector<c10::ClassAttribute>& sub_attributes = submod._ivalue()->type()->getAttributes(); | |||
| for (size_t j = 0; j < sub_attributes.size(); j++) | |||
| { | |||
| const std::string& sub_name = sub_attributes[j].getName(); | |||
| const c10::IValue& sub_ivalue = submod._ivalue()->getSlot(j); | |||
| if (sub_name.empty()) | |||
| continue; | |||
| if (sub_ivalue.isTensor()) | |||
| attrs.emplace(name + "." + sub_name, sub_ivalue.toTensor()); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| bool TorchModuleProxy::hasattr(const std::string& name) const | |||
| { | |||
| return attrs.find(name) != attrs.end(); | |||
| } | |||
| const TorchTensorProxy& TorchModuleProxy::attr(const std::string& name) const | |||
| { | |||
| return attrs.at(name); | |||
| } | |||
| FuseModulePass::~FuseModulePass() | |||
| { | |||
| } | |||
| void FuseModulePass::write(Operator* /*op*/, const TorchGraphProxy& /*graph*/) const | |||
| { | |||
| } | |||
| void FuseModulePass::write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& /*mod*/) const | |||
| { | |||
| write(op, graph); | |||
| } | |||
| static std::vector<const FuseModulePass*> g_global_pnnx_fuse_module_passes; | |||
| const std::vector<const FuseModulePass*>& get_global_pnnx_fuse_module_passes() | |||
| { | |||
| return g_global_pnnx_fuse_module_passes; | |||
| } | |||
| FuseModulePassRegister::FuseModulePassRegister(const FuseModulePass* _pass) | |||
| : pass(_pass) | |||
| { | |||
| g_global_pnnx_fuse_module_passes.push_back(pass); | |||
| } | |||
| FuseModulePassRegister::~FuseModulePassRegister() | |||
| { | |||
| delete pass; | |||
| } | |||
| } // namespace pnnx | |||
| @@ -0,0 +1,151 @@ | |||
| // Tencent is pleased to support the open source community by making ncnn available. | |||
| // | |||
| // Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. | |||
| // | |||
| // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | |||
| // in compliance with the License. You may obtain a copy of the License at | |||
| // | |||
| // https://opensource.org/licenses/BSD-3-Clause | |||
| // | |||
| // Unless required by applicable law or agreed to in writing, software distributed | |||
| // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #ifndef PNNX_FUSE_MODULE_PASS_H | |||
| #define PNNX_FUSE_MODULE_PASS_H | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "ir.h" | |||
| namespace torch { | |||
| namespace jit { | |||
| struct Graph; | |||
| struct Module; | |||
| struct Node; | |||
| struct Value; | |||
| } // namespace jit | |||
| } // namespace torch | |||
| namespace at { | |||
| struct Tensor; | |||
| } // namespace at | |||
| namespace pnnx { | |||
| class TorchNodeProxy | |||
| { | |||
| public: | |||
| TorchNodeProxy(const torch::jit::Node* _node) | |||
| : node(_node) | |||
| { | |||
| } | |||
| std::string kind() const; | |||
| bool hasNamedInput(const std::string& name) const; | |||
| const torch::jit::Value* namedInput(const std::string& name) const; | |||
| int input_count() const; | |||
| const torch::jit::Value* input(int i) const; | |||
| int output_count() const; | |||
| const torch::jit::Value* output(int i) const; | |||
| bool is_input_none(int i) const; | |||
| public: | |||
| const torch::jit::Node* node; | |||
| }; | |||
| class TorchGraphProxy | |||
| { | |||
| public: | |||
| TorchGraphProxy(const std::shared_ptr<torch::jit::Graph> _graph); | |||
| // bool has_node(const std::string& name) const; | |||
| const TorchNodeProxy* find_node_by_kind(const std::string& kind) const; | |||
| const TorchNodeProxy* find_producer_node_by_value(const torch::jit::Value* value) const; | |||
| int input_count() const; | |||
| const torch::jit::Value* input(int i) const; | |||
| int output_count() const; | |||
| const torch::jit::Value* output(int i) const; | |||
| void dump() const; | |||
| public: | |||
| const std::shared_ptr<torch::jit::Graph> graph; | |||
| public: | |||
| std::vector<TorchNodeProxy> nodes; | |||
| }; | |||
| class TorchTensorProxyPrivate; | |||
| class TorchTensorProxy | |||
| { | |||
| public: | |||
| TorchTensorProxy(const at::Tensor& _t); | |||
| ~TorchTensorProxy(); | |||
| TorchTensorProxy(const TorchTensorProxy&) = delete; | |||
| TorchTensorProxy& operator=(const TorchTensorProxy&) = delete; | |||
| const at::Tensor& t() const; | |||
| int size(size_t i) const; | |||
| private: | |||
| TorchTensorProxyPrivate* const d; | |||
| }; | |||
| class TorchModuleProxy | |||
| { | |||
| public: | |||
| TorchModuleProxy(const torch::jit::Module& _mod); | |||
| bool hasattr(const std::string& name) const; | |||
| const TorchTensorProxy& attr(const std::string& name) const; | |||
| public: | |||
| const torch::jit::Module& mod; | |||
| private: | |||
| std::unordered_map<std::string, TorchTensorProxy> attrs; | |||
| }; | |||
| class FuseModulePass | |||
| { | |||
| public: | |||
| virtual ~FuseModulePass(); | |||
| virtual const char* match_type_str() const = 0; | |||
| virtual const char* type_str() const = 0; | |||
| virtual void write(Operator* op, const TorchGraphProxy& graph) const; | |||
| virtual void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const; | |||
| }; | |||
| class FuseModulePassRegister | |||
| { | |||
| public: | |||
| FuseModulePassRegister(const FuseModulePass* pass); | |||
| ~FuseModulePassRegister(); | |||
| const FuseModulePass* pass; | |||
| }; | |||
| const std::vector<const FuseModulePass*>& get_global_pnnx_fuse_module_passes(); | |||
| #define REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(CLASS) \ | |||
| static FuseModulePassRegister g_global_pnnx_fusemodulepass_##CLASS##_register(new CLASS); | |||
| } // namespace pnnx | |||
| #endif // PNNX_FUSE_MODULE_PASS_H | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.AdaptiveAvgPool1d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* adaptive_avg_pool1d = find_node_by_kind(graph, "aten::adaptive_avg_pool1d"); | |||
| const TorchNodeProxy* adaptive_avg_pool1d = graph.find_node_by_kind("aten::adaptive_avg_pool1d"); | |||
| op->params["output_size"] = adaptive_avg_pool1d->namedInput("output_size"); | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.AdaptiveAvgPool2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* adaptive_avg_pool2d = find_node_by_kind(graph, "aten::adaptive_avg_pool2d"); | |||
| const TorchNodeProxy* adaptive_avg_pool2d = graph.find_node_by_kind("aten::adaptive_avg_pool2d"); | |||
| op->params["output_size"] = adaptive_avg_pool2d->namedInput("output_size"); | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.AdaptiveAvgPool3d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* adaptive_avg_pool3d = find_node_by_kind(graph, "aten::adaptive_avg_pool3d"); | |||
| const TorchNodeProxy* adaptive_avg_pool3d = graph.find_node_by_kind("aten::adaptive_avg_pool3d"); | |||
| op->params["output_size"] = adaptive_avg_pool3d->namedInput("output_size"); | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,12 +29,14 @@ public: | |||
| return "nn.AdaptiveMaxPool1d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* adaptive_max_pool1d = find_node_by_kind(graph, "aten::adaptive_max_pool1d"); | |||
| const TorchNodeProxy* adaptive_max_pool1d = graph.find_node_by_kind("aten::adaptive_max_pool1d"); | |||
| const TorchNodeProxy* graph_out = graph.find_producer_node_by_value(graph.output(0)); | |||
| op->params["output_size"] = adaptive_max_pool1d->namedInput("output_size"); | |||
| op->params["return_indices"] = graph->outputs()[0]->node()->kind() == c10::prim::TupleConstruct ? true : false; | |||
| op->params["return_indices"] = graph_out->kind() == "prim::TupleConstruct" ? true : false; | |||
| } | |||
| }; | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,12 +29,14 @@ public: | |||
| return "nn.AdaptiveMaxPool2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* adaptive_max_pool2d = find_node_by_kind(graph, "aten::adaptive_max_pool2d"); | |||
| const TorchNodeProxy* adaptive_max_pool2d = graph.find_node_by_kind("aten::adaptive_max_pool2d"); | |||
| const TorchNodeProxy* graph_out = graph.find_producer_node_by_value(graph.output(0)); | |||
| op->params["output_size"] = adaptive_max_pool2d->namedInput("output_size"); | |||
| op->params["return_indices"] = graph->outputs()[0]->node()->kind() == c10::prim::TupleConstruct ? true : false; | |||
| op->params["return_indices"] = graph_out->kind() == "prim::TupleConstruct" ? true : false; | |||
| } | |||
| }; | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,12 +29,14 @@ public: | |||
| return "nn.AdaptiveMaxPool3d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* adaptive_max_pool3d = find_node_by_kind(graph, "aten::adaptive_max_pool3d"); | |||
| const TorchNodeProxy* adaptive_max_pool3d = graph.find_node_by_kind("aten::adaptive_max_pool3d"); | |||
| const TorchNodeProxy* graph_out = graph.find_producer_node_by_value(graph.output(0)); | |||
| op->params["output_size"] = adaptive_max_pool3d->namedInput("output_size"); | |||
| op->params["return_indices"] = graph->outputs()[0]->node()->kind() == c10::prim::TupleConstruct ? true : false; | |||
| op->params["return_indices"] = graph_out->kind() == "prim::TupleConstruct" ? true : false; | |||
| } | |||
| }; | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.AvgPool1d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* avg_pool1d = find_node_by_kind(graph, "aten::avg_pool1d"); | |||
| const TorchNodeProxy* avg_pool1d = graph.find_node_by_kind("aten::avg_pool1d"); | |||
| op->params["kernel_size"] = avg_pool1d->namedInput("kernel_size"); | |||
| op->params["stride"] = avg_pool1d->namedInput("stride"); | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.AvgPool2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* avg_pool2d = find_node_by_kind(graph, "aten::avg_pool2d"); | |||
| const TorchNodeProxy* avg_pool2d = graph.find_node_by_kind("aten::avg_pool2d"); | |||
| op->params["kernel_size"] = avg_pool2d->namedInput("kernel_size"); | |||
| op->params["stride"] = avg_pool2d->namedInput("stride"); | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.AvgPool3d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* avg_pool3d = find_node_by_kind(graph, "aten::avg_pool3d"); | |||
| const TorchNodeProxy* avg_pool3d = graph.find_node_by_kind("aten::avg_pool3d"); | |||
| op->params["kernel_size"] = avg_pool3d->namedInput("kernel_size"); | |||
| op->params["stride"] = avg_pool3d->namedInput("stride"); | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,12 +29,12 @@ public: | |||
| return "nn.BatchNorm1d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| const torch::jit::Node* bn = find_node_by_kind(graph, "aten::batch_norm"); | |||
| const TorchNodeProxy* bn = graph.find_node_by_kind("aten::batch_norm"); | |||
| const auto& running_mean = mod.attr("running_mean").toTensor(); | |||
| const auto& running_var = mod.attr("running_var").toTensor(); | |||
| const TorchTensorProxy& running_mean = mod.attr("running_mean"); | |||
| const TorchTensorProxy& running_var = mod.attr("running_var"); | |||
| op->params["num_features"] = running_mean.size(0); | |||
| op->params["eps"] = bn->namedInput("eps"); | |||
| @@ -46,8 +44,8 @@ public: | |||
| op->attrs["running_var"] = running_var; | |||
| if (mod.hasattr("weight") && mod.hasattr("bias")) | |||
| { | |||
| op->attrs["weight"] = mod.attr("weight").toTensor(); | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| op->attrs["weight"] = mod.attr("weight"); | |||
| op->attrs["bias"] = mod.attr("bias"); | |||
| } | |||
| } | |||
| }; | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,12 +29,12 @@ public: | |||
| return "nn.BatchNorm2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| const torch::jit::Node* bn = find_node_by_kind(graph, "aten::batch_norm"); | |||
| const TorchNodeProxy* bn = graph.find_node_by_kind("aten::batch_norm"); | |||
| const auto& running_mean = mod.attr("running_mean").toTensor(); | |||
| const auto& running_var = mod.attr("running_var").toTensor(); | |||
| const TorchTensorProxy& running_mean = mod.attr("running_mean"); | |||
| const TorchTensorProxy& running_var = mod.attr("running_var"); | |||
| op->params["num_features"] = running_mean.size(0); | |||
| op->params["eps"] = bn->namedInput("eps"); | |||
| @@ -46,8 +44,8 @@ public: | |||
| op->attrs["running_var"] = running_var; | |||
| if (mod.hasattr("weight") && mod.hasattr("bias")) | |||
| { | |||
| op->attrs["weight"] = mod.attr("weight").toTensor(); | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| op->attrs["weight"] = mod.attr("weight"); | |||
| op->attrs["bias"] = mod.attr("bias"); | |||
| } | |||
| } | |||
| }; | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,12 +29,12 @@ public: | |||
| return "nn.BatchNorm3d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| const torch::jit::Node* bn = find_node_by_kind(graph, "aten::batch_norm"); | |||
| const TorchNodeProxy* bn = graph.find_node_by_kind("aten::batch_norm"); | |||
| const auto& running_mean = mod.attr("running_mean").toTensor(); | |||
| const auto& running_var = mod.attr("running_var").toTensor(); | |||
| const TorchTensorProxy& running_mean = mod.attr("running_mean"); | |||
| const TorchTensorProxy& running_var = mod.attr("running_var"); | |||
| op->params["num_features"] = running_mean.size(0); | |||
| op->params["eps"] = bn->namedInput("eps"); | |||
| @@ -46,8 +44,8 @@ public: | |||
| op->attrs["running_var"] = running_var; | |||
| if (mod.hasattr("weight") && mod.hasattr("bias")) | |||
| { | |||
| op->attrs["weight"] = mod.attr("weight").toTensor(); | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| op->attrs["weight"] = mod.attr("weight"); | |||
| op->attrs["bias"] = mod.attr("bias"); | |||
| } | |||
| } | |||
| }; | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.CELU"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* celu = find_node_by_kind(graph, "aten::celu"); | |||
| const TorchNodeProxy* celu = graph.find_node_by_kind("aten::celu"); | |||
| op->params["alpha"] = celu->namedInput("alpha"); | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.ChannelShuffle"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* channel_shuffle = find_node_by_kind(graph, "aten::channel_shuffle"); | |||
| const TorchNodeProxy* channel_shuffle = graph.find_node_by_kind("aten::channel_shuffle"); | |||
| op->params["groups"] = channel_shuffle->namedInput("groups"); | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,10 +29,10 @@ public: | |||
| return "nn.ConstantPad1d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); | |||
| const torch::jit::Node* constant_pad_nd = find_node_by_kind(graph, "aten::constant_pad_nd"); | |||
| const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); | |||
| const TorchNodeProxy* constant_pad_nd = graph.find_node_by_kind("aten::constant_pad_nd"); | |||
| if (!pad) | |||
| { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,10 +29,10 @@ public: | |||
| return "nn.ConstantPad2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); | |||
| const torch::jit::Node* constant_pad_nd = find_node_by_kind(graph, "aten::constant_pad_nd"); | |||
| const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); | |||
| const TorchNodeProxy* constant_pad_nd = graph.find_node_by_kind("aten::constant_pad_nd"); | |||
| if (!pad) | |||
| { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,10 +29,10 @@ public: | |||
| return "nn.ConstantPad3d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); | |||
| const torch::jit::Node* constant_pad_nd = find_node_by_kind(graph, "aten::constant_pad_nd"); | |||
| const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); | |||
| const TorchNodeProxy* constant_pad_nd = graph.find_node_by_kind("aten::constant_pad_nd"); | |||
| if (!pad) | |||
| { | |||
| @@ -12,11 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| // #include "../pass_level3/fuse_expression.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -33,7 +29,7 @@ public: | |||
| return "nn.Conv1d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| // { | |||
| // pnnx::Graph pnnx_graph; | |||
| @@ -45,18 +41,18 @@ public: | |||
| // pnnx_graph.save("tmp.param", "tmp.bin"); | |||
| // } | |||
| const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); | |||
| const torch::jit::Node* convolution_mode = find_node_by_kind(graph, "aten::_convolution_mode"); | |||
| const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); | |||
| const torch::jit::Node* reflection_pad1d = find_node_by_kind(graph, "aten::reflection_pad1d"); | |||
| const torch::jit::Node* replication_pad1d = find_node_by_kind(graph, "aten::replication_pad1d"); | |||
| const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution"); | |||
| const TorchNodeProxy* convolution_mode = graph.find_node_by_kind("aten::_convolution_mode"); | |||
| const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); | |||
| const TorchNodeProxy* reflection_pad1d = graph.find_node_by_kind("aten::reflection_pad1d"); | |||
| const TorchNodeProxy* replication_pad1d = graph.find_node_by_kind("aten::replication_pad1d"); | |||
| if (convolution_mode) | |||
| { | |||
| convolution = convolution_mode; | |||
| } | |||
| const auto& weight = mod.attr("weight").toTensor(); | |||
| const TorchTensorProxy& weight = mod.attr("weight"); | |||
| op->params["groups"] = convolution->namedInput("groups"); | |||
| op->params["in_channels"] = weight.size(1) * op->params["groups"].i; | |||
| @@ -131,7 +127,7 @@ public: | |||
| op->attrs["weight"] = weight; | |||
| if (mod.hasattr("bias")) | |||
| { | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| op->attrs["bias"] = mod.attr("bias"); | |||
| } | |||
| } | |||
| }; | |||
| @@ -12,11 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| // #include "../pass_level3/fuse_expression.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -33,7 +29,7 @@ public: | |||
| return "nn.Conv2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| // { | |||
| // pnnx::Graph pnnx_graph; | |||
| @@ -45,18 +41,18 @@ public: | |||
| // pnnx_graph.save("tmp.param", "tmp.bin"); | |||
| // } | |||
| const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); | |||
| const torch::jit::Node* convolution_mode = find_node_by_kind(graph, "aten::_convolution_mode"); | |||
| const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); | |||
| const torch::jit::Node* reflection_pad2d = find_node_by_kind(graph, "aten::reflection_pad2d"); | |||
| const torch::jit::Node* replication_pad2d = find_node_by_kind(graph, "aten::replication_pad2d"); | |||
| const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution"); | |||
| const TorchNodeProxy* convolution_mode = graph.find_node_by_kind("aten::_convolution_mode"); | |||
| const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); | |||
| const TorchNodeProxy* reflection_pad2d = graph.find_node_by_kind("aten::reflection_pad2d"); | |||
| const TorchNodeProxy* replication_pad2d = graph.find_node_by_kind("aten::replication_pad2d"); | |||
| if (convolution_mode) | |||
| { | |||
| convolution = convolution_mode; | |||
| } | |||
| const auto& weight = mod.attr("weight").toTensor(); | |||
| const TorchTensorProxy& weight = mod.attr("weight"); | |||
| op->params["groups"] = convolution->namedInput("groups"); | |||
| op->params["in_channels"] = weight.size(1) * op->params["groups"].i; | |||
| @@ -126,12 +122,12 @@ public: | |||
| op->params["padding"] = convolution->namedInput("padding"); | |||
| } | |||
| op->params["dilation"] = convolution->namedInput("dilation"); | |||
| op->params["bias"] = mod.hasattr("bias") && mod.attr("bias").isTensor(); | |||
| op->params["bias"] = mod.hasattr("bias"); | |||
| op->attrs["weight"] = weight; | |||
| if (mod.hasattr("bias") && mod.attr("bias").isTensor()) | |||
| if (mod.hasattr("bias")) | |||
| { | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| op->attrs["bias"] = mod.attr("bias"); | |||
| } | |||
| } | |||
| }; | |||
| @@ -12,11 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| // #include "../pass_level3/fuse_expression.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -33,7 +29,7 @@ public: | |||
| return "nn.Conv3d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| // { | |||
| // pnnx::Graph pnnx_graph; | |||
| @@ -45,18 +41,18 @@ public: | |||
| // pnnx_graph.save("tmp.param", "tmp.bin"); | |||
| // } | |||
| const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); | |||
| const torch::jit::Node* convolution_mode = find_node_by_kind(graph, "aten::_convolution_mode"); | |||
| const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); | |||
| const torch::jit::Node* reflection_pad3d = find_node_by_kind(graph, "aten::reflection_pad3d"); | |||
| const torch::jit::Node* replication_pad3d = find_node_by_kind(graph, "aten::replication_pad3d"); | |||
| const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution"); | |||
| const TorchNodeProxy* convolution_mode = graph.find_node_by_kind("aten::_convolution_mode"); | |||
| const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); | |||
| const TorchNodeProxy* reflection_pad3d = graph.find_node_by_kind("aten::reflection_pad3d"); | |||
| const TorchNodeProxy* replication_pad3d = graph.find_node_by_kind("aten::replication_pad3d"); | |||
| if (convolution_mode) | |||
| { | |||
| convolution = convolution_mode; | |||
| } | |||
| const auto& weight = mod.attr("weight").toTensor(); | |||
| const TorchTensorProxy& weight = mod.attr("weight"); | |||
| op->params["groups"] = convolution->namedInput("groups"); | |||
| op->params["in_channels"] = weight.size(1) * op->params["groups"].i; | |||
| @@ -131,7 +127,7 @@ public: | |||
| op->attrs["weight"] = weight; | |||
| if (mod.hasattr("bias")) | |||
| { | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| op->attrs["bias"] = mod.attr("bias"); | |||
| } | |||
| } | |||
| }; | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,11 +29,11 @@ public: | |||
| return "nn.ConvTranspose1d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); | |||
| const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution"); | |||
| const auto& weight = mod.attr("weight").toTensor(); | |||
| const TorchTensorProxy& weight = mod.attr("weight"); | |||
| op->params["groups"] = convolution->namedInput("groups"); | |||
| op->params["in_channels"] = weight.size(0); | |||
| @@ -50,7 +48,7 @@ public: | |||
| op->attrs["weight"] = weight; | |||
| if (mod.hasattr("bias")) | |||
| { | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| op->attrs["bias"] = mod.attr("bias"); | |||
| } | |||
| if (op->inputs.size() > 1) | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,11 +29,11 @@ public: | |||
| return "nn.ConvTranspose2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); | |||
| const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution"); | |||
| const auto& weight = mod.attr("weight").toTensor(); | |||
| const TorchTensorProxy& weight = mod.attr("weight"); | |||
| op->params["groups"] = convolution->namedInput("groups"); | |||
| op->params["in_channels"] = weight.size(0); | |||
| @@ -50,7 +48,7 @@ public: | |||
| op->attrs["weight"] = weight; | |||
| if (mod.hasattr("bias")) | |||
| { | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| op->attrs["bias"] = mod.attr("bias"); | |||
| } | |||
| if (op->inputs.size() > 1) | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,11 +29,11 @@ public: | |||
| return "nn.ConvTranspose3d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution"); | |||
| const TorchNodeProxy* convolution = graph.find_node_by_kind("aten::_convolution"); | |||
| const auto& weight = mod.attr("weight").toTensor(); | |||
| const TorchTensorProxy& weight = mod.attr("weight"); | |||
| op->params["groups"] = convolution->namedInput("groups"); | |||
| op->params["in_channels"] = weight.size(0); | |||
| @@ -50,7 +48,7 @@ public: | |||
| op->attrs["weight"] = weight; | |||
| if (mod.hasattr("bias")) | |||
| { | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| op->attrs["bias"] = mod.attr("bias"); | |||
| } | |||
| if (op->inputs.size() > 1) | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.ELU"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* elu = find_node_by_kind(graph, "aten::elu"); | |||
| const TorchNodeProxy* elu = graph.find_node_by_kind("aten::elu"); | |||
| op->params["alpha"] = elu->namedInput("alpha"); | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,11 +29,11 @@ public: | |||
| return "nn.Embedding"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| const torch::jit::Node* embedding = find_node_by_kind(graph, "aten::embedding"); | |||
| const TorchNodeProxy* embedding = graph.find_node_by_kind("aten::embedding"); | |||
| const auto& weight = mod.attr("weight").toTensor(); | |||
| const TorchTensorProxy& weight = mod.attr("weight"); | |||
| op->params["num_embeddings"] = weight.size(0); | |||
| op->params["embedding_dim"] = weight.size(1); | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.Fold"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* col2im = find_node_by_kind(graph, "aten::col2im"); | |||
| const TorchNodeProxy* col2im = graph.find_node_by_kind("aten::col2im"); | |||
| op->params["output_size"] = col2im->namedInput("output_size"); | |||
| op->params["kernel_size"] = col2im->namedInput("kernel_size"); | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.GELU"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* gelu = find_node_by_kind(graph, "aten::gelu"); | |||
| const TorchNodeProxy* gelu = graph.find_node_by_kind("aten::gelu"); | |||
| if (gelu->hasNamedInput("approximate")) | |||
| { | |||
| @@ -13,9 +13,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -32,9 +30,9 @@ public: | |||
| return "nn.GLU"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* glu = find_node_by_kind(graph, "aten::glu"); | |||
| const TorchNodeProxy* glu = graph.find_node_by_kind("aten::glu"); | |||
| op->params["dim"] = glu->namedInput("dim"); | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,17 +29,17 @@ public: | |||
| return "nn.GRU"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| // mod.dump(true, true, true); | |||
| // graph->dump(); | |||
| const torch::jit::Node* gru = find_node_by_kind(graph, "aten::gru"); | |||
| const TorchNodeProxy* gru = graph.find_node_by_kind("aten::gru"); | |||
| const torch::jit::Node* return_tuple = find_node_by_kind(graph, "prim::TupleConstruct"); | |||
| if (return_tuple && return_tuple->inputs().size() == 2 && gru->outputs().size() == 2 | |||
| && return_tuple->inputs()[0] == gru->outputs()[1] && return_tuple->inputs()[1] == gru->outputs()[0]) | |||
| const TorchNodeProxy* return_tuple = graph.find_node_by_kind("prim::TupleConstruct"); | |||
| if (return_tuple && return_tuple->input_count() == 2 && gru->output_count() == 2 | |||
| && return_tuple->input(0) == gru->output(1) && return_tuple->input(1) == gru->output(0)) | |||
| { | |||
| // mark the swapped output tuple | |||
| // we would restore the fine order in pass_level3/fuse_rnn_unpack | |||
| @@ -54,7 +52,7 @@ public: | |||
| // fprintf(stderr, "arg %s\n", aa.name().c_str()); | |||
| // } | |||
| const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor(); | |||
| const TorchTensorProxy& weight_ih_l0 = mod.attr("weight_ih_l0"); | |||
| op->params["input_size"] = weight_ih_l0.size(1); | |||
| op->params["hidden_size"] = weight_ih_l0.size(0) / 3; | |||
| @@ -72,16 +70,16 @@ public: | |||
| std::string weight_ih_lk_key = std::string("weight_ih_l") + std::to_string(k); | |||
| std::string weight_hh_lk_key = std::string("weight_hh_l") + std::to_string(k); | |||
| op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key).toTensor(); | |||
| op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key).toTensor(); | |||
| op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key); | |||
| op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key); | |||
| if (bias) | |||
| { | |||
| std::string bias_ih_lk_key = std::string("bias_ih_l") + std::to_string(k); | |||
| std::string bias_hh_lk_key = std::string("bias_hh_l") + std::to_string(k); | |||
| op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key).toTensor(); | |||
| op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor(); | |||
| op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key); | |||
| op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key); | |||
| } | |||
| if (bidirectional) | |||
| @@ -89,16 +87,16 @@ public: | |||
| std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse"; | |||
| std::string weight_hh_lk_reverse_key = std::string("weight_hh_l") + std::to_string(k) + "_reverse"; | |||
| op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key).toTensor(); | |||
| op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key).toTensor(); | |||
| op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key); | |||
| op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key); | |||
| if (bias) | |||
| { | |||
| std::string bias_ih_lk_reverse_key = std::string("bias_ih_l") + std::to_string(k) + "_reverse"; | |||
| std::string bias_hh_lk_reverse_key = std::string("bias_hh_l") + std::to_string(k) + "_reverse"; | |||
| op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor(); | |||
| op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor(); | |||
| op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key); | |||
| op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key); | |||
| } | |||
| } | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,11 +29,11 @@ public: | |||
| return "nn.GroupNorm"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| // graph->dump(); | |||
| const torch::jit::Node* gn = find_node_by_kind(graph, "aten::group_norm"); | |||
| const TorchNodeProxy* gn = graph.find_node_by_kind("aten::group_norm"); | |||
| // for (auto aa : gn->schema().arguments()) | |||
| // { | |||
| @@ -48,12 +46,12 @@ public: | |||
| if (mod.hasattr("weight") && mod.hasattr("bias")) | |||
| { | |||
| const auto& weight = mod.attr("weight").toTensor(); | |||
| const auto& weight = mod.attr("weight"); | |||
| op->params["num_channels"] = weight.size(0); | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| op->attrs["bias"] = mod.attr("bias"); | |||
| } | |||
| else | |||
| { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.Hardshrink"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* hardshrink = find_node_by_kind(graph, "aten::hardshrink"); | |||
| const TorchNodeProxy* hardshrink = graph.find_node_by_kind("aten::hardshrink"); | |||
| op->params["lambd"] = hardshrink->namedInput("lambd"); | |||
| } | |||
| @@ -12,7 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,7 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.Hardtanh"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* hardtanh = find_node_by_kind(graph, "aten::hardtanh"); | |||
| const TorchNodeProxy* hardtanh = graph.find_node_by_kind("aten::hardtanh"); | |||
| op->params["min_val"] = hardtanh->namedInput("min_val"); | |||
| op->params["max_val"] = hardtanh->namedInput("max_val"); | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,11 +29,11 @@ public: | |||
| return "nn.InstanceNorm1d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| // graph->dump(); | |||
| const torch::jit::Node* in = find_node_by_kind(graph, "aten::instance_norm"); | |||
| const TorchNodeProxy* in = graph.find_node_by_kind("aten::instance_norm"); | |||
| // for (auto aa : in->schema().arguments()) | |||
| // { | |||
| @@ -48,22 +46,22 @@ public: | |||
| if (mod.hasattr("weight") && mod.hasattr("bias")) | |||
| { | |||
| const auto& weight = mod.attr("weight").toTensor(); | |||
| const TorchTensorProxy& weight = mod.attr("weight"); | |||
| op->params["num_features"] = weight.size(0); | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| op->attrs["bias"] = mod.attr("bias"); | |||
| } | |||
| if (mod.hasattr("running_mean") && mod.hasattr("running_var")) | |||
| { | |||
| const auto& running_mean = mod.attr("running_mean").toTensor(); | |||
| const TorchTensorProxy& running_mean = mod.attr("running_mean"); | |||
| op->params["num_features"] = running_mean.size(0); | |||
| op->attrs["running_mean"] = running_mean; | |||
| op->attrs["running_var"] = mod.attr("running_var").toTensor(); | |||
| op->attrs["running_var"] = mod.attr("running_var"); | |||
| } | |||
| // take num_features from input shape | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,11 +29,11 @@ public: | |||
| return "nn.InstanceNorm2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| // graph->dump(); | |||
| const torch::jit::Node* in = find_node_by_kind(graph, "aten::instance_norm"); | |||
| const TorchNodeProxy* in = graph.find_node_by_kind("aten::instance_norm"); | |||
| // for (auto aa : in->schema().arguments()) | |||
| // { | |||
| @@ -48,22 +46,22 @@ public: | |||
| if (mod.hasattr("weight") && mod.hasattr("bias")) | |||
| { | |||
| const auto& weight = mod.attr("weight").toTensor(); | |||
| const TorchTensorProxy& weight = mod.attr("weight"); | |||
| op->params["num_features"] = weight.size(0); | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| op->attrs["bias"] = mod.attr("bias"); | |||
| } | |||
| if (mod.hasattr("running_mean") && mod.hasattr("running_var")) | |||
| { | |||
| const auto& running_mean = mod.attr("running_mean").toTensor(); | |||
| const TorchTensorProxy& running_mean = mod.attr("running_mean"); | |||
| op->params["num_features"] = running_mean.size(0); | |||
| op->attrs["running_mean"] = running_mean; | |||
| op->attrs["running_var"] = mod.attr("running_var").toTensor(); | |||
| op->attrs["running_var"] = mod.attr("running_var"); | |||
| } | |||
| // take num_features from input shape | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,11 +29,11 @@ public: | |||
| return "nn.InstanceNorm3d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| // graph->dump(); | |||
| const torch::jit::Node* in = find_node_by_kind(graph, "aten::instance_norm"); | |||
| const TorchNodeProxy* in = graph.find_node_by_kind("aten::instance_norm"); | |||
| // for (auto aa : in->schema().arguments()) | |||
| // { | |||
| @@ -48,22 +46,22 @@ public: | |||
| if (mod.hasattr("weight") && mod.hasattr("bias")) | |||
| { | |||
| const auto& weight = mod.attr("weight").toTensor(); | |||
| const TorchTensorProxy& weight = mod.attr("weight"); | |||
| op->params["num_features"] = weight.size(0); | |||
| op->attrs["weight"] = weight; | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| op->attrs["bias"] = mod.attr("bias"); | |||
| } | |||
| if (mod.hasattr("running_mean") && mod.hasattr("running_var")) | |||
| { | |||
| const auto& running_mean = mod.attr("running_mean").toTensor(); | |||
| const TorchTensorProxy& running_mean = mod.attr("running_mean"); | |||
| op->params["num_features"] = running_mean.size(0); | |||
| op->attrs["running_mean"] = running_mean; | |||
| op->attrs["running_var"] = mod.attr("running_var").toTensor(); | |||
| op->attrs["running_var"] = mod.attr("running_var"); | |||
| } | |||
| // take num_features from input shape | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,21 +29,24 @@ public: | |||
| return "nn.LPPool1d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* pow = find_node_by_kind(graph, "aten::pow"); | |||
| op->params["norm_type"] = pow->inputs()[1]; | |||
| const TorchNodeProxy* pow = graph.find_node_by_kind("aten::pow"); | |||
| op->params["norm_type"] = pow->input(1); | |||
| const TorchNodeProxy* avg_pool1d = graph.find_node_by_kind("aten::avg_pool1d"); | |||
| const torch::jit::Node* avg_pool1d = find_node_by_kind(graph, "aten::avg_pool1d"); | |||
| const TorchNodeProxy* kernel_size = graph.find_producer_node_by_value(avg_pool1d->namedInput("kernel_size")); | |||
| const TorchNodeProxy* stride = graph.find_producer_node_by_value(avg_pool1d->namedInput("stride")); | |||
| op->params["kernel_size"] = avg_pool1d->namedInput("kernel_size")->node()->inputs()[0]; | |||
| if (avg_pool1d->namedInput("stride")->node()->inputs().size() == 0) | |||
| op->params["kernel_size"] = kernel_size->input(0); | |||
| if (stride->input_count() == 0) | |||
| { | |||
| op->params["stride"] = op->params["kernel_size"]; | |||
| } | |||
| else | |||
| { | |||
| op->params["stride"] = avg_pool1d->namedInput("stride")->node()->inputs()[0]; | |||
| op->params["stride"] = stride->input(0); | |||
| } | |||
| op->params["ceil_mode"] = avg_pool1d->namedInput("ceil_mode"); | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,15 +29,17 @@ public: | |||
| return "nn.LPPool2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* pow = find_node_by_kind(graph, "aten::pow"); | |||
| op->params["norm_type"] = pow->inputs()[1]; | |||
| const TorchNodeProxy* pow = graph.find_node_by_kind("aten::pow"); | |||
| op->params["norm_type"] = pow->input(1); | |||
| const TorchNodeProxy* avg_pool2d = graph.find_node_by_kind("aten::avg_pool2d"); | |||
| const torch::jit::Node* avg_pool2d = find_node_by_kind(graph, "aten::avg_pool2d"); | |||
| const TorchNodeProxy* stride = graph.find_producer_node_by_value(avg_pool2d->namedInput("stride")); | |||
| op->params["kernel_size"] = avg_pool2d->namedInput("kernel_size"); | |||
| if (avg_pool2d->namedInput("stride")->node()->inputs().size() == 0) | |||
| if (stride->input_count() == 0) | |||
| { | |||
| op->params["stride"] = op->params["kernel_size"]; | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,17 +29,17 @@ public: | |||
| return "nn.LSTM"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| // mod.dump(true, true, true); | |||
| // | |||
| // graph->dump(); | |||
| const torch::jit::Node* lstm = find_node_by_kind(graph, "aten::lstm"); | |||
| const TorchNodeProxy* lstm = graph.find_node_by_kind("aten::lstm"); | |||
| const torch::jit::Node* return_tuple = find_node_by_kind(graph, "prim::TupleConstruct"); | |||
| if (return_tuple && return_tuple->inputs().size() == 3 && lstm->outputs().size() == 3 | |||
| && return_tuple->inputs()[0] == lstm->outputs()[1] && return_tuple->inputs()[1] == lstm->outputs()[2] && return_tuple->inputs()[2] == lstm->outputs()[0]) | |||
| const TorchNodeProxy* return_tuple = graph.find_node_by_kind("prim::TupleConstruct"); | |||
| if (return_tuple && return_tuple->input_count() == 3 && lstm->output_count() == 3 | |||
| && return_tuple->input(0) == lstm->output(1) && return_tuple->input(1) == lstm->output(2) && return_tuple->input(2) == lstm->output(0)) | |||
| { | |||
| // mark the swapped output tuple | |||
| // we would restore the fine order in pass_level3/fuse_rnn_unpack | |||
| @@ -54,8 +52,8 @@ public: | |||
| // fprintf(stderr, "arg %s\n", aa.name().c_str()); | |||
| // } | |||
| const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor(); | |||
| const auto& weight_hh_l0 = mod.attr("weight_hh_l0").toTensor(); | |||
| const TorchTensorProxy& weight_ih_l0 = mod.attr("weight_ih_l0"); | |||
| const TorchTensorProxy& weight_hh_l0 = mod.attr("weight_hh_l0"); | |||
| op->params["input_size"] = weight_ih_l0.size(1); | |||
| op->params["hidden_size"] = weight_ih_l0.size(0) / 4; | |||
| @@ -75,23 +73,23 @@ public: | |||
| std::string weight_ih_lk_key = std::string("weight_ih_l") + std::to_string(k); | |||
| std::string weight_hh_lk_key = std::string("weight_hh_l") + std::to_string(k); | |||
| op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key).toTensor(); | |||
| op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key).toTensor(); | |||
| op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key); | |||
| op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key); | |||
| if (bias) | |||
| { | |||
| std::string bias_ih_lk_key = std::string("bias_ih_l") + std::to_string(k); | |||
| std::string bias_hh_lk_key = std::string("bias_hh_l") + std::to_string(k); | |||
| op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key).toTensor(); | |||
| op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor(); | |||
| op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key); | |||
| op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key); | |||
| } | |||
| if (proj_size > 0) | |||
| { | |||
| std::string weight_hr_lk_key = std::string("weight_hr_l") + std::to_string(k); | |||
| op->attrs[weight_hr_lk_key] = mod.attr(weight_hr_lk_key).toTensor(); | |||
| op->attrs[weight_hr_lk_key] = mod.attr(weight_hr_lk_key); | |||
| } | |||
| if (bidirectional) | |||
| @@ -99,23 +97,23 @@ public: | |||
| std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse"; | |||
| std::string weight_hh_lk_reverse_key = std::string("weight_hh_l") + std::to_string(k) + "_reverse"; | |||
| op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key).toTensor(); | |||
| op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key).toTensor(); | |||
| op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key); | |||
| op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key); | |||
| if (bias) | |||
| { | |||
| std::string bias_ih_lk_reverse_key = std::string("bias_ih_l") + std::to_string(k) + "_reverse"; | |||
| std::string bias_hh_lk_reverse_key = std::string("bias_hh_l") + std::to_string(k) + "_reverse"; | |||
| op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor(); | |||
| op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor(); | |||
| op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key); | |||
| op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key); | |||
| } | |||
| if (proj_size > 0) | |||
| { | |||
| std::string weight_hr_lk_reverse_key = std::string("weight_hr_l") + std::to_string(k) + "_reverse"; | |||
| op->attrs[weight_hr_lk_reverse_key] = mod.attr(weight_hr_lk_reverse_key).toTensor(); | |||
| op->attrs[weight_hr_lk_reverse_key] = mod.attr(weight_hr_lk_reverse_key); | |||
| } | |||
| } | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.LayerNorm"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| const torch::jit::Node* ln = find_node_by_kind(graph, "aten::layer_norm"); | |||
| const TorchNodeProxy* ln = graph.find_node_by_kind("aten::layer_norm"); | |||
| op->params["normalized_shape"] = ln->namedInput("normalized_shape"); | |||
| op->params["eps"] = ln->namedInput("eps"); | |||
| @@ -41,8 +39,8 @@ public: | |||
| if (mod.hasattr("weight") && mod.hasattr("bias")) | |||
| { | |||
| op->attrs["weight"] = mod.attr("weight").toTensor(); | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| op->attrs["weight"] = mod.attr("weight"); | |||
| op->attrs["bias"] = mod.attr("bias"); | |||
| } | |||
| } | |||
| }; | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,10 +29,10 @@ public: | |||
| return "nn.LeakyReLU"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* leaky_relu = find_node_by_kind(graph, "aten::leaky_relu"); | |||
| const torch::jit::Node* leaky_relu_ = find_node_by_kind(graph, "aten::leaky_relu_"); | |||
| const TorchNodeProxy* leaky_relu = graph.find_node_by_kind("aten::leaky_relu"); | |||
| const TorchNodeProxy* leaky_relu_ = graph.find_node_by_kind("aten::leaky_relu_"); | |||
| if (leaky_relu_) | |||
| { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,20 +29,20 @@ public: | |||
| return "nn.Linear"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& /*graph*/, const TorchModuleProxy& mod) const | |||
| { | |||
| const torch::jit::Node* addmm = find_node_by_kind(graph, "aten::addmm"); | |||
| // const TorchNodeProxy* addmm = graph.find_node_by_kind("aten::addmm"); | |||
| const auto& weight = mod.attr("weight").toTensor(); | |||
| const TorchTensorProxy& weight = mod.attr("weight"); | |||
| op->params["in_features"] = weight.size(1); | |||
| op->params["out_features"] = weight.size(0); | |||
| op->params["bias"] = mod.hasattr("bias") && mod.attr("bias").isTensor(); | |||
| op->params["bias"] = mod.hasattr("bias"); | |||
| op->attrs["weight"] = weight; | |||
| if (mod.hasattr("bias") && mod.attr("bias").isTensor()) | |||
| if (mod.hasattr("bias")) | |||
| { | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| op->attrs["bias"] = mod.attr("bias"); | |||
| } | |||
| } | |||
| }; | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,26 +29,27 @@ public: | |||
| return "nn.LocalResponseNorm"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* avg_pool = find_node_by_kind(graph, "aten::avg_pool2d"); | |||
| const torch::jit::Node* avg_pool3d = find_node_by_kind(graph, "aten::avg_pool3d"); | |||
| const TorchNodeProxy* avg_pool = graph.find_node_by_kind("aten::avg_pool2d"); | |||
| const TorchNodeProxy* avg_pool3d = graph.find_node_by_kind("aten::avg_pool3d"); | |||
| if (avg_pool3d) | |||
| { | |||
| avg_pool = avg_pool3d; | |||
| } | |||
| op->params["size"] = avg_pool->namedInput("kernel_size")->node()->inputs()[0]; | |||
| const TorchNodeProxy* kernel_size = graph.find_producer_node_by_value(avg_pool->namedInput("kernel_size")); | |||
| op->params["size"] = kernel_size->input(0); | |||
| const torch::jit::Node* pow = find_node_by_kind(graph, "aten::pow"); | |||
| op->params["beta"] = pow->inputs()[1]; | |||
| const TorchNodeProxy* pow = graph.find_node_by_kind("aten::pow"); | |||
| op->params["beta"] = pow->input(1); | |||
| const torch::jit::Node* add = pow->inputs()[0]->node(); | |||
| op->params["k"] = add->inputs()[1]; | |||
| const TorchNodeProxy* add = graph.find_producer_node_by_value(pow->input(0)); | |||
| op->params["k"] = add->input(1); | |||
| const torch::jit::Node* mul = add->inputs()[0]->node(); | |||
| op->params["alpha"] = mul->inputs()[1]; | |||
| const TorchNodeProxy* mul = graph.find_producer_node_by_value(add->input(0)); | |||
| op->params["alpha"] = mul->input(1); | |||
| } | |||
| }; | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.LogSoftmax"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* log_softmax = find_node_by_kind(graph, "aten::log_softmax"); | |||
| const TorchNodeProxy* log_softmax = graph.find_node_by_kind("aten::log_softmax"); | |||
| op->params["dim"] = log_softmax->namedInput("dim"); | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,10 +29,10 @@ public: | |||
| return "nn.MaxPool1d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* max_pool1d = find_node_by_kind(graph, "aten::max_pool1d"); | |||
| const torch::jit::Node* max_pool1d_with_indices = find_node_by_kind(graph, "aten::max_pool1d_with_indices"); | |||
| const TorchNodeProxy* max_pool1d = graph.find_node_by_kind("aten::max_pool1d"); | |||
| const TorchNodeProxy* max_pool1d_with_indices = graph.find_node_by_kind("aten::max_pool1d_with_indices"); | |||
| if (max_pool1d_with_indices) | |||
| { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,10 +29,10 @@ public: | |||
| return "nn.MaxPool2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* max_pool2d = find_node_by_kind(graph, "aten::max_pool2d"); | |||
| const torch::jit::Node* max_pool2d_with_indices = find_node_by_kind(graph, "aten::max_pool2d_with_indices"); | |||
| const TorchNodeProxy* max_pool2d = graph.find_node_by_kind("aten::max_pool2d"); | |||
| const TorchNodeProxy* max_pool2d_with_indices = graph.find_node_by_kind("aten::max_pool2d_with_indices"); | |||
| if (max_pool2d_with_indices) | |||
| { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,10 +29,10 @@ public: | |||
| return "nn.MaxPool3d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* max_pool3d = find_node_by_kind(graph, "aten::max_pool3d"); | |||
| const torch::jit::Node* max_pool3d_with_indices = find_node_by_kind(graph, "aten::max_pool3d_with_indices"); | |||
| const TorchNodeProxy* max_pool3d = graph.find_node_by_kind("aten::max_pool3d"); | |||
| const TorchNodeProxy* max_pool3d_with_indices = graph.find_node_by_kind("aten::max_pool3d_with_indices"); | |||
| if (max_pool3d_with_indices) | |||
| { | |||
| @@ -12,7 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,11 +12,13 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include <torch/csrc/api/include/torch/torch.h> | |||
| // #include "pass_level1.h" | |||
| // | |||
| // #include <torch/csrc/api/include/torch/torch.h> | |||
| // | |||
| // #include "../utils.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -33,19 +35,19 @@ public: | |||
| return "nn.MultiheadAttention"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| // mod.dump(false, false, false); | |||
| // graph->dump(); | |||
| const torch::jit::Node* multi_head_attention = find_node_by_kind(graph, "aten::_native_multi_head_attention"); | |||
| const TorchNodeProxy* multi_head_attention = graph.find_node_by_kind("aten::_native_multi_head_attention"); | |||
| if (multi_head_attention) | |||
| { | |||
| op->params["num_heads"] = multi_head_attention->namedInput("num_head"); | |||
| op->params["batch_first"] = true; | |||
| op->params["add_zero_attn"] = false; | |||
| if (multi_head_attention->hasNamedInput("mask") && multi_head_attention->namedInput("mask") == graph->inputs()[graph->inputs().size() - 1]) | |||
| if (multi_head_attention->hasNamedInput("mask") && multi_head_attention->namedInput("mask") == graph.input(graph.input_count() - 1)) | |||
| { | |||
| size_t input_count = op->inputs.size(); | |||
| op->inputnames.resize(input_count); | |||
| @@ -54,25 +56,28 @@ public: | |||
| } | |||
| else | |||
| { | |||
| const torch::jit::Node* div_num_heads = find_node_by_kind(graph, "aten::div"); | |||
| const torch::jit::Node* div_num_heads_18 = find_node_by_kind(graph, "aten::floor_divide"); | |||
| const TorchNodeProxy* div_num_heads = graph.find_node_by_kind("aten::div"); | |||
| const TorchNodeProxy* div_num_heads_18 = graph.find_node_by_kind("aten::floor_divide"); | |||
| if (div_num_heads_18) | |||
| { | |||
| div_num_heads = div_num_heads_18; | |||
| } | |||
| op->params["num_heads"] = (int)div_num_heads->input(1)->node()->t(torch::jit::attr::value).item<int64_t>(); | |||
| // const TorchNodeProxy* div_num_heads_input_1 = graph.find_producer_node_by_value(div_num_heads->input(1)); | |||
| // op->params["num_heads"] = (int)div_num_heads_input_1->t(torch::jit::attr::value).item<int64_t>(); | |||
| op->params["num_heads"] = div_num_heads->input(1); | |||
| const torch::jit::Node* transpose_batch_seq = find_node_by_kind(graph, "aten::transpose"); | |||
| const TorchNodeProxy* transpose_batch_seq = graph.find_node_by_kind("aten::transpose"); | |||
| int transpose_dim0 = transpose_batch_seq->input(1)->node()->i(torch::jit::attr::value); | |||
| int transpose_dim1 = transpose_batch_seq->input(2)->node()->i(torch::jit::attr::value); | |||
| if (transpose_dim0 == 1 && transpose_dim1 == 0) | |||
| Parameter transpose_dim0 = transpose_batch_seq->input(1); | |||
| Parameter transpose_dim1 = transpose_batch_seq->input(2); | |||
| if (transpose_dim0.i == 1 && transpose_dim1.i == 0) | |||
| { | |||
| op->params["batch_first"] = true; | |||
| } | |||
| const torch::jit::Node* add_zero_attn = find_node_by_kind(graph, "aten::zeros"); | |||
| const TorchNodeProxy* add_zero_attn = graph.find_node_by_kind("aten::zeros"); | |||
| if (add_zero_attn) | |||
| { | |||
| op->params["add_zero_attn"] = true; | |||
| @@ -82,10 +87,10 @@ public: | |||
| op->params["add_zero_attn"] = false; | |||
| } | |||
| const torch::jit::Node* scaled_dot_product_attention = find_node_by_kind(graph, "aten::scaled_dot_product_attention"); | |||
| const TorchNodeProxy* scaled_dot_product_attention = graph.find_node_by_kind("aten::scaled_dot_product_attention"); | |||
| if (scaled_dot_product_attention) | |||
| { | |||
| if (scaled_dot_product_attention->input(3)->type()->kind() != c10::TypeKind::NoneType) | |||
| if (!scaled_dot_product_attention->is_input_none(3)) | |||
| { | |||
| size_t input_count = op->inputs.size(); | |||
| op->inputnames.resize(input_count); | |||
| @@ -94,7 +99,7 @@ public: | |||
| } | |||
| // find attention mask addition pattern pre torch-2.1 | |||
| const torch::jit::Node* has_attn_mask = find_node_by_kind(graph, "aten::baddbmm"); | |||
| const TorchNodeProxy* has_attn_mask = graph.find_node_by_kind("aten::baddbmm"); | |||
| if (has_attn_mask) | |||
| { | |||
| size_t input_count = op->inputs.size(); | |||
| @@ -106,14 +111,14 @@ public: | |||
| // attn = torch.bmm(Q, K) | |||
| // input0 = torch.add_(attn, attn_mask) | |||
| // attn0 = torch.softmax(input0, -1) | |||
| const torch::jit::Node* softmax = find_node_by_kind(graph, "aten::softmax"); | |||
| const TorchNodeProxy* softmax = graph.find_node_by_kind("aten::softmax"); | |||
| if (softmax) | |||
| { | |||
| const torch::jit::Node* add_ = softmax->input(0)->node(); | |||
| if (add_ && add_->kind().toDisplayString() == std::string("aten::add_")) | |||
| const TorchNodeProxy* add_ = graph.find_producer_node_by_value(softmax->input(0)); | |||
| if (add_ && add_->kind() == "aten::add_") | |||
| { | |||
| const torch::jit::Node* bmm = add_->input(0)->node(); | |||
| if (bmm && bmm->kind().toDisplayString() == std::string("aten::bmm")) | |||
| const TorchNodeProxy* bmm = graph.find_producer_node_by_value(add_->input(0)); | |||
| if (bmm && bmm->kind() == "aten::bmm") | |||
| { | |||
| size_t input_count = op->inputs.size(); | |||
| op->inputnames.resize(input_count); | |||
| @@ -125,7 +130,7 @@ public: | |||
| if (mod.hasattr("in_proj_weight")) | |||
| { | |||
| const auto& in_proj_weight = mod.attr("in_proj_weight").toTensor(); | |||
| const TorchTensorProxy& in_proj_weight = mod.attr("in_proj_weight"); | |||
| op->params["embed_dim"] = in_proj_weight.size(1); | |||
| op->params["kdim"] = in_proj_weight.size(1); | |||
| @@ -134,9 +139,9 @@ public: | |||
| } | |||
| else | |||
| { | |||
| const auto& q_proj_weight = mod.attr("q_proj_weight").toTensor(); | |||
| const auto& k_proj_weight = mod.attr("k_proj_weight").toTensor(); | |||
| const auto& v_proj_weight = mod.attr("v_proj_weight").toTensor(); | |||
| const TorchTensorProxy& q_proj_weight = mod.attr("q_proj_weight"); | |||
| const TorchTensorProxy& k_proj_weight = mod.attr("k_proj_weight"); | |||
| const TorchTensorProxy& v_proj_weight = mod.attr("v_proj_weight"); | |||
| op->params["embed_dim"] = q_proj_weight.size(1); | |||
| op->params["kdim"] = k_proj_weight.size(1); | |||
| @@ -146,15 +151,15 @@ public: | |||
| op->attrs["v_proj_weight"] = v_proj_weight; | |||
| } | |||
| const auto& out_proj_weight = mod.attr("out_proj").toModule().attr("weight").toTensor(); | |||
| const TorchTensorProxy& out_proj_weight = mod.attr("out_proj.weight"); | |||
| op->attrs["out_proj.weight"] = out_proj_weight; | |||
| if (mod.hasattr("in_proj_bias") && mod.attr("out_proj").toModule().hasattr("bias")) | |||
| if (mod.hasattr("in_proj_bias") && mod.hasattr("out_proj.bias")) | |||
| { | |||
| // bias=True | |||
| const auto& in_proj_bias = mod.attr("in_proj_bias").toTensor(); | |||
| const auto& out_proj_bias = mod.attr("out_proj").toModule().attr("bias").toTensor(); | |||
| const TorchTensorProxy& in_proj_bias = mod.attr("in_proj_bias"); | |||
| const TorchTensorProxy& out_proj_bias = mod.attr("out_proj.bias"); | |||
| op->params["bias"] = true; | |||
| op->attrs["in_proj_bias"] = in_proj_bias; | |||
| @@ -166,9 +171,9 @@ public: | |||
| // the output projection bias always there no matter bias is False in pytorch 1.8 | |||
| // this behavior changes since https://github.com/pytorch/pytorch/commit/58d1b3639bc07f9519de18e5a18e575f260c7eeb | |||
| if (mod.attr("out_proj").toModule().hasattr("bias")) | |||
| if (mod.hasattr("out_proj.bias")) | |||
| { | |||
| const auto& out_proj_bias = mod.attr("out_proj").toModule().attr("bias").toTensor(); | |||
| const TorchTensorProxy& out_proj_bias = mod.attr("out_proj.bias"); | |||
| op->attrs["out_proj.bias"] = out_proj_bias; | |||
| } | |||
| } | |||
| @@ -176,8 +181,8 @@ public: | |||
| if (mod.hasattr("bias_k") && mod.hasattr("bias_v")) | |||
| { | |||
| // add_bias_kv=True | |||
| const auto& bias_k = mod.attr("bias_k").toTensor(); | |||
| const auto& bias_v = mod.attr("bias_v").toTensor(); | |||
| const TorchTensorProxy& bias_k = mod.attr("bias_k"); | |||
| const TorchTensorProxy& bias_v = mod.attr("bias_v"); | |||
| op->params["add_bias_kv"] = true; | |||
| op->attrs["bias_k"] = bias_k; | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.PReLU"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& /*graph*/, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& /*graph*/, const TorchModuleProxy& mod) const | |||
| { | |||
| const auto& weight = mod.attr("weight").toTensor(); | |||
| const TorchTensorProxy& weight = mod.attr("weight"); | |||
| op->params["num_parameters"] = weight.size(0); | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.PixelShuffle"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* pixel_shuffle = find_node_by_kind(graph, "aten::pixel_shuffle"); | |||
| const TorchNodeProxy* pixel_shuffle = graph.find_node_by_kind("aten::pixel_shuffle"); | |||
| op->params["upscale_factor"] = pixel_shuffle->namedInput("upscale_factor"); | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.PixelUnshuffle"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* pixel_unshuffle = find_node_by_kind(graph, "aten::pixel_unshuffle"); | |||
| const TorchNodeProxy* pixel_unshuffle = graph.find_node_by_kind("aten::pixel_unshuffle"); | |||
| op->params["downscale_factor"] = pixel_unshuffle->namedInput("downscale_factor"); | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.RMSNorm"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| const torch::jit::Node* rmsn = find_node_by_kind(graph, "aten::rms_norm"); | |||
| const TorchNodeProxy* rmsn = graph.find_node_by_kind("aten::rms_norm"); | |||
| op->params["normalized_shape"] = rmsn->namedInput("normalized_shape"); | |||
| op->params["eps"] = rmsn->namedInput("eps"); | |||
| @@ -41,7 +39,7 @@ public: | |||
| if (mod.hasattr("weight")) | |||
| { | |||
| op->attrs["weight"] = mod.attr("weight").toTensor(); | |||
| op->attrs["weight"] = mod.attr("weight"); | |||
| } | |||
| } | |||
| }; | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,23 +29,23 @@ public: | |||
| return "nn.RNN"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| // mod.dump(true, true, true); | |||
| // graph->dump(); | |||
| const torch::jit::Node* rnn = find_node_by_kind(graph, "aten::rnn_tanh"); | |||
| const torch::jit::Node* rnn_relu = find_node_by_kind(graph, "aten::rnn_relu"); | |||
| const TorchNodeProxy* rnn = graph.find_node_by_kind("aten::rnn_tanh"); | |||
| const TorchNodeProxy* rnn_relu = graph.find_node_by_kind("aten::rnn_relu"); | |||
| if (rnn_relu) | |||
| { | |||
| rnn = rnn_relu; | |||
| } | |||
| const torch::jit::Node* return_tuple = find_node_by_kind(graph, "prim::TupleConstruct"); | |||
| if (return_tuple && return_tuple->inputs().size() == 2 && rnn->outputs().size() == 2 | |||
| && return_tuple->inputs()[0] == rnn->outputs()[1] && return_tuple->inputs()[1] == rnn->outputs()[0]) | |||
| const TorchNodeProxy* return_tuple = graph.find_node_by_kind("prim::TupleConstruct"); | |||
| if (return_tuple && return_tuple->input_count() == 2 && rnn->output_count() == 2 | |||
| && return_tuple->input(0) == rnn->output(1) && return_tuple->input(1) == rnn->output(0)) | |||
| { | |||
| // mark the swapped output tuple | |||
| // we would restore the fine order in pass_level3/fuse_rnn_unpack | |||
| @@ -60,7 +58,7 @@ public: | |||
| // fprintf(stderr, "arg %s\n", aa.name().c_str()); | |||
| // } | |||
| const auto& weight_ih_l0 = mod.attr("weight_ih_l0").toTensor(); | |||
| const TorchTensorProxy& weight_ih_l0 = mod.attr("weight_ih_l0"); | |||
| op->params["input_size"] = weight_ih_l0.size(1); | |||
| op->params["hidden_size"] = weight_ih_l0.size(0); | |||
| @@ -79,16 +77,16 @@ public: | |||
| std::string weight_ih_lk_key = std::string("weight_ih_l") + std::to_string(k); | |||
| std::string weight_hh_lk_key = std::string("weight_hh_l") + std::to_string(k); | |||
| op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key).toTensor(); | |||
| op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key).toTensor(); | |||
| op->attrs[weight_ih_lk_key] = mod.attr(weight_ih_lk_key); | |||
| op->attrs[weight_hh_lk_key] = mod.attr(weight_hh_lk_key); | |||
| if (bias) | |||
| { | |||
| std::string bias_ih_lk_key = std::string("bias_ih_l") + std::to_string(k); | |||
| std::string bias_hh_lk_key = std::string("bias_hh_l") + std::to_string(k); | |||
| op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key).toTensor(); | |||
| op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key).toTensor(); | |||
| op->attrs[bias_ih_lk_key] = mod.attr(bias_ih_lk_key); | |||
| op->attrs[bias_hh_lk_key] = mod.attr(bias_hh_lk_key); | |||
| } | |||
| if (bidirectional) | |||
| @@ -96,16 +94,16 @@ public: | |||
| std::string weight_ih_lk_reverse_key = std::string("weight_ih_l") + std::to_string(k) + "_reverse"; | |||
| std::string weight_hh_lk_reverse_key = std::string("weight_hh_l") + std::to_string(k) + "_reverse"; | |||
| op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key).toTensor(); | |||
| op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key).toTensor(); | |||
| op->attrs[weight_ih_lk_reverse_key] = mod.attr(weight_ih_lk_reverse_key); | |||
| op->attrs[weight_hh_lk_reverse_key] = mod.attr(weight_hh_lk_reverse_key); | |||
| if (bias) | |||
| { | |||
| std::string bias_ih_lk_reverse_key = std::string("bias_ih_l") + std::to_string(k) + "_reverse"; | |||
| std::string bias_hh_lk_reverse_key = std::string("bias_hh_l") + std::to_string(k) + "_reverse"; | |||
| op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key).toTensor(); | |||
| op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key).toTensor(); | |||
| op->attrs[bias_ih_lk_reverse_key] = mod.attr(bias_ih_lk_reverse_key); | |||
| op->attrs[bias_hh_lk_reverse_key] = mod.attr(bias_hh_lk_reverse_key); | |||
| } | |||
| } | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.RReLU"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* rrelu = find_node_by_kind(graph, "aten::rrelu"); | |||
| const TorchNodeProxy* rrelu = graph.find_node_by_kind("aten::rrelu"); | |||
| op->params["lower"] = rrelu->namedInput("lower"); | |||
| op->params["upper"] = rrelu->namedInput("upper"); | |||
| @@ -12,7 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,7 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,10 +29,10 @@ public: | |||
| return "nn.ReflectionPad1d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); | |||
| const torch::jit::Node* reflection_pad1d = find_node_by_kind(graph, "aten::reflection_pad1d"); | |||
| const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); | |||
| const TorchNodeProxy* reflection_pad1d = graph.find_node_by_kind("aten::reflection_pad1d"); | |||
| if (pad) | |||
| { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,10 +29,10 @@ public: | |||
| return "nn.ReflectionPad2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); | |||
| const torch::jit::Node* reflection_pad2d = find_node_by_kind(graph, "aten::reflection_pad2d"); | |||
| const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); | |||
| const TorchNodeProxy* reflection_pad2d = graph.find_node_by_kind("aten::reflection_pad2d"); | |||
| if (pad) | |||
| { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,10 +29,10 @@ public: | |||
| return "nn.ReplicationPad1d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); | |||
| const torch::jit::Node* replication_pad1d = find_node_by_kind(graph, "aten::replication_pad1d"); | |||
| const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); | |||
| const TorchNodeProxy* replication_pad1d = graph.find_node_by_kind("aten::replication_pad1d"); | |||
| if (pad) | |||
| { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,10 +29,10 @@ public: | |||
| return "nn.ReplicationPad2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); | |||
| const torch::jit::Node* replication_pad2d = find_node_by_kind(graph, "aten::replication_pad2d"); | |||
| const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); | |||
| const TorchNodeProxy* replication_pad2d = graph.find_node_by_kind("aten::replication_pad2d"); | |||
| if (pad) | |||
| { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,10 +29,10 @@ public: | |||
| return "nn.ReplicationPad3d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); | |||
| const torch::jit::Node* replication_pad3d = find_node_by_kind(graph, "aten::replication_pad3d"); | |||
| const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); | |||
| const TorchNodeProxy* replication_pad3d = graph.find_node_by_kind("aten::replication_pad3d"); | |||
| if (pad) | |||
| { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,7 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,7 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.Softmax"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* softmax = find_node_by_kind(graph, "aten::softmax"); | |||
| const TorchNodeProxy* softmax = graph.find_node_by_kind("aten::softmax"); | |||
| op->params["dim"] = softmax->namedInput("dim"); | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.Softmin"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* softmax = find_node_by_kind(graph, "aten::softmax"); | |||
| const TorchNodeProxy* softmax = graph.find_node_by_kind("aten::softmax"); | |||
| op->params["dim"] = softmax->namedInput("dim"); | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.Softplus"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* softplus = find_node_by_kind(graph, "aten::softplus"); | |||
| const TorchNodeProxy* softplus = graph.find_node_by_kind("aten::softplus"); | |||
| op->params["beta"] = softplus->namedInput("beta"); | |||
| op->params["threshold"] = softplus->namedInput("threshold"); | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.Softshrink"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* softshrink = find_node_by_kind(graph, "aten::softshrink"); | |||
| const TorchNodeProxy* softshrink = graph.find_node_by_kind("aten::softshrink"); | |||
| op->params["lambd"] = softshrink->namedInput("lambd"); | |||
| } | |||
| @@ -12,7 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,7 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,7 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.Threshold"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* threshold = find_node_by_kind(graph, "aten::threshold"); | |||
| const TorchNodeProxy* threshold = graph.find_node_by_kind("aten::threshold"); | |||
| op->params["threshold"] = threshold->namedInput("threshold"); | |||
| op->params["value"] = threshold->namedInput("value"); | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.Unfold"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* im2col = find_node_by_kind(graph, "aten::im2col"); | |||
| const TorchNodeProxy* im2col = graph.find_node_by_kind("aten::im2col"); | |||
| op->params["kernel_size"] = im2col->namedInput("kernel_size"); | |||
| op->params["stride"] = im2col->namedInput("stride"); | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,19 +29,19 @@ public: | |||
| return "nn.Upsample"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* upsample_nearest1d = find_node_by_kind(graph, "aten::upsample_nearest1d"); | |||
| const torch::jit::Node* upsample_linear1d = find_node_by_kind(graph, "aten::upsample_linear1d"); | |||
| const TorchNodeProxy* upsample_nearest1d = graph.find_node_by_kind("aten::upsample_nearest1d"); | |||
| const TorchNodeProxy* upsample_linear1d = graph.find_node_by_kind("aten::upsample_linear1d"); | |||
| const torch::jit::Node* upsample_nearest2d = find_node_by_kind(graph, "aten::upsample_nearest2d"); | |||
| const torch::jit::Node* upsample_bilinear2d = find_node_by_kind(graph, "aten::upsample_bilinear2d"); | |||
| const torch::jit::Node* upsample_bicubic2d = find_node_by_kind(graph, "aten::upsample_bicubic2d"); | |||
| const TorchNodeProxy* upsample_nearest2d = graph.find_node_by_kind("aten::upsample_nearest2d"); | |||
| const TorchNodeProxy* upsample_bilinear2d = graph.find_node_by_kind("aten::upsample_bilinear2d"); | |||
| const TorchNodeProxy* upsample_bicubic2d = graph.find_node_by_kind("aten::upsample_bicubic2d"); | |||
| const torch::jit::Node* upsample_nearest3d = find_node_by_kind(graph, "aten::upsample_nearest3d"); | |||
| const torch::jit::Node* upsample_trilinear3d = find_node_by_kind(graph, "aten::upsample_trilinear3d"); | |||
| const TorchNodeProxy* upsample_nearest3d = graph.find_node_by_kind("aten::upsample_nearest3d"); | |||
| const TorchNodeProxy* upsample_trilinear3d = graph.find_node_by_kind("aten::upsample_trilinear3d"); | |||
| const torch::jit::Node* upsample = 0; | |||
| const TorchNodeProxy* upsample = 0; | |||
| if (upsample_nearest1d) | |||
| { | |||
| upsample = upsample_nearest1d; | |||
| @@ -136,12 +134,17 @@ public: | |||
| std::vector<float> scale_factor; | |||
| try | |||
| { | |||
| const torch::jit::Node* size_list = find_node_by_kind(graph, "prim::ListConstruct"); | |||
| for (auto x : size_list->inputs()) | |||
| const TorchNodeProxy* size_list = graph.find_node_by_kind("prim::ListConstruct"); | |||
| const int size_list_input_count = size_list->input_count(); | |||
| for (int i = 0; i < size_list_input_count; i++) | |||
| { | |||
| auto scale_tensor = x->node()->inputs()[0]->node()->inputs()[0]->node()->inputs()[0]->node()->inputs()[1]->node()->inputs()[0]->node()->inputs()[0]->node(); | |||
| auto t = scale_tensor->t(torch::jit::attr::value); | |||
| float s = (float)t.item<double>(); | |||
| const TorchNodeProxy* scale_tensor = graph.find_producer_node_by_value(graph.find_producer_node_by_value(graph.find_producer_node_by_value(graph.find_producer_node_by_value(graph.find_producer_node_by_value(graph.find_producer_node_by_value(graph.find_producer_node_by_value(size_list->input(i))->input(0))->input(0))->input(0))->input(1))->input(0))->input(0)); | |||
| // auto t = scale_tensor->t(torch::jit::attr::value); | |||
| // float s = (float)t.item<double>(); | |||
| Parameter ps = scale_tensor->node; | |||
| float s = ps.f; | |||
| scale_factor.push_back(s); | |||
| } | |||
| @@ -150,7 +153,7 @@ public: | |||
| catch (...) | |||
| { | |||
| fprintf(stderr, "unhandled upsample recompute_scale_factor graph"); | |||
| graph->dump(); | |||
| graph.dump(); | |||
| } | |||
| } | |||
| } | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.UpsamplingBilinear2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* upsample = find_node_by_kind(graph, "aten::upsample_bilinear2d"); | |||
| const TorchNodeProxy* upsample = graph.find_node_by_kind("aten::upsample_bilinear2d"); | |||
| if (upsample->hasNamedInput("output_size")) | |||
| { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,9 +29,9 @@ public: | |||
| return "nn.UpsamplingNearest2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* upsample = find_node_by_kind(graph, "aten::upsample_nearest2d"); | |||
| const TorchNodeProxy* upsample = graph.find_node_by_kind("aten::upsample_nearest2d"); | |||
| if (upsample->hasNamedInput("output_size")) | |||
| { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,10 +29,10 @@ public: | |||
| return "nn.ZeroPad2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad"); | |||
| const torch::jit::Node* constant_pad_nd = find_node_by_kind(graph, "aten::constant_pad_nd"); | |||
| const TorchNodeProxy* pad = graph.find_node_by_kind("aten::pad"); | |||
| const TorchNodeProxy* constant_pad_nd = graph.find_node_by_kind("aten::constant_pad_nd"); | |||
| if (!pad) | |||
| { | |||
| @@ -12,9 +12,11 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "fuse_module_pass.h" | |||
| #include "../utils.h" | |||
| #include <torch/script.h> | |||
| #include <torch/csrc/jit/api/module.h> | |||
| #include <torch/csrc/jit/passes/quantization/helper.h> | |||
| namespace pnnx { | |||
| @@ -31,11 +33,13 @@ public: | |||
| return "nn.quantized.Conv2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& _mod) const | |||
| { | |||
| const auto& mod = _mod.mod; | |||
| // graph->dump(); | |||
| const torch::jit::Node* quantized_convolution = find_node_by_kind(graph, "quantized::conv2d"); | |||
| const TorchNodeProxy* quantized_convolution = graph.find_node_by_kind("quantized::conv2d"); | |||
| // for (auto aa : quantized_convolution->schema().arguments()) | |||
| // { | |||
| @@ -113,11 +117,13 @@ public: | |||
| return "nn.intrinsic.quantized.ConvReLU2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& _mod) const | |||
| { | |||
| const auto& mod = _mod.mod; | |||
| // graph->dump(); | |||
| const torch::jit::Node* quantized_convolution = find_node_by_kind(graph, "quantized::conv2d_relu"); | |||
| const TorchNodeProxy* quantized_convolution = graph.find_node_by_kind("quantized::conv2d_relu"); | |||
| // for (auto aa : quantized_convolution->schema().arguments()) | |||
| // { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,19 +29,19 @@ public: | |||
| return "nn.quantized.DeQuantize"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| { | |||
| // mod.dump(true, false, false); | |||
| // graph->dump(); | |||
| const torch::jit::Node* dequantize = find_node_by_kind(graph, "aten::dequantize"); | |||
| // for (auto aa : dequantize->schema().arguments()) | |||
| // { | |||
| // fprintf(stderr, "arg %s\n", aa.name().c_str()); | |||
| // } | |||
| } | |||
| // void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| // { | |||
| // // mod.dump(true, false, false); | |||
| // | |||
| // // graph->dump(); | |||
| // | |||
| // const torch::jit::Node* dequantize = find_node_by_kind(graph, "aten::dequantize"); | |||
| // | |||
| // // for (auto aa : dequantize->schema().arguments()) | |||
| // // { | |||
| // // fprintf(stderr, "arg %s\n", aa.name().c_str()); | |||
| // // } | |||
| // } | |||
| }; | |||
| REGISTER_GLOBAL_PNNX_FUSE_MODULE_PASS(DeQuantize) | |||
| @@ -12,9 +12,11 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "fuse_module_pass.h" | |||
| #include "../utils.h" | |||
| #include <torch/script.h> | |||
| #include <torch/csrc/jit/api/module.h> | |||
| #include <torch/csrc/jit/passes/quantization/helper.h> | |||
| namespace pnnx { | |||
| @@ -31,13 +33,15 @@ public: | |||
| return "nn.quantized.Linear"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& _mod) const | |||
| { | |||
| const auto& mod = _mod.mod; | |||
| // mod.dump(true, false, false); | |||
| // graph->dump(); | |||
| const torch::jit::Node* quantized_linear = find_node_by_kind(graph, "quantized::linear"); | |||
| const TorchNodeProxy* quantized_linear = graph.find_node_by_kind("quantized::linear"); | |||
| // for (auto aa : quantized_linear->schema().arguments()) | |||
| // { | |||
| @@ -99,13 +103,15 @@ public: | |||
| return "nn.intrinsic.quantized.LinearReLU"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& _mod) const | |||
| { | |||
| const auto& mod = _mod.mod; | |||
| // mod.dump(true, false, false); | |||
| graph->dump(); | |||
| graph.dump(); | |||
| const torch::jit::Node* quantized_linear = find_node_by_kind(graph, "quantized::linear_relu"); | |||
| const TorchNodeProxy* quantized_linear = graph.find_node_by_kind("quantized::linear_relu"); | |||
| // for (auto aa : quantized_linear->schema().arguments()) | |||
| // { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,13 +29,13 @@ public: | |||
| return "nn.quantized.Quantize"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| // mod.dump(true, false, false); | |||
| // graph->dump(); | |||
| const torch::jit::Node* quantize_per_tensor = find_node_by_kind(graph, "aten::quantize_per_tensor"); | |||
| const TorchNodeProxy* quantize_per_tensor = graph.find_node_by_kind("aten::quantize_per_tensor"); | |||
| // for (auto aa : quantize_per_tensor->schema().arguments()) | |||
| // { | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,11 +29,11 @@ public: | |||
| return "torchvision.ops.DeformConv2d"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& mod) const | |||
| void write(Operator* op, const TorchGraphProxy& graph, const TorchModuleProxy& mod) const | |||
| { | |||
| const torch::jit::Node* deform_conv2d = find_node_by_kind(graph, "torchvision::deform_conv2d"); | |||
| const TorchNodeProxy* deform_conv2d = graph.find_node_by_kind("torchvision::deform_conv2d"); | |||
| const auto& weight = mod.attr("weight").toTensor(); | |||
| const TorchTensorProxy& weight = mod.attr("weight"); | |||
| const Parameter stride_w = deform_conv2d->namedInput("stride_w"); | |||
| const Parameter stride_h = deform_conv2d->namedInput("stride_h"); | |||
| @@ -56,7 +54,7 @@ public: | |||
| op->attrs["weight"] = weight; | |||
| if (mod.hasattr("bias")) | |||
| { | |||
| op->attrs["bias"] = mod.attr("bias").toTensor(); | |||
| op->attrs["bias"] = mod.attr("bias"); | |||
| } | |||
| } | |||
| }; | |||
| @@ -12,9 +12,7 @@ | |||
| // CONDITIONS OF ANY KIND, either express or implied. See the License for the | |||
| // specific language governing permissions and limitations under the License. | |||
| #include "pass_level1.h" | |||
| #include "../utils.h" | |||
| #include "fuse_module_pass.h" | |||
| namespace pnnx { | |||
| @@ -31,11 +29,11 @@ public: | |||
| return "torchvision.ops.RoIAlign"; | |||
| } | |||
| void write(Operator* op, const std::shared_ptr<torch::jit::Graph>& graph, const torch::jit::Module& /*mod*/) const | |||
| void write(Operator* op, const TorchGraphProxy& graph) const | |||
| { | |||
| const torch::jit::Node* roi_align = find_node_by_kind(graph, "torchvision::roi_align"); | |||
| const TorchNodeProxy* roi_align = graph.find_node_by_kind("torchvision::roi_align"); | |||
| if (roi_align->inputs()[0] == graph->inputs()[2] && roi_align->inputs()[1] == graph->inputs()[1]) | |||
| if (roi_align->input(0) == graph.input(2) && roi_align->input(1) == graph.input(1)) | |||
| { | |||
| fprintf(stderr, "roi_align inputs swapped detected !\n"); | |||
| std::swap(op->inputs[0], op->inputs[1]); | |||
| @@ -15,22 +15,8 @@ | |||
| #ifndef PNNX_UTILS_H | |||
| #define PNNX_UTILS_H | |||
| #if BUILD_TORCH2PNNX | |||
| #include <memory> | |||
| namespace torch { | |||
| namespace jit { | |||
| struct Graph; | |||
| struct Node; | |||
| } // namespace jit | |||
| } // namespace torch | |||
| #endif | |||
| namespace pnnx { | |||
| #if BUILD_TORCH2PNNX | |||
| const torch::jit::Node* find_node_by_kind(const std::shared_ptr<torch::jit::Graph>& graph, const std::string& kind); | |||
| #endif | |||
| unsigned short float32_to_float16(float value); | |||
| float float16_to_float32(unsigned short value); | |||