// Tencent is pleased to support the open source community by making ncnn available. // // Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at // // https://opensource.org/licenses/BSD-3-Clause // // Unless required by applicable law or agreed to in writing, software distributed // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. #include "pass_level2.h" namespace pnnx { class Tensor_to : public GraphRewriterPass { public: const char* match_pattern_graph() const { return R"PNNXIR(7767517 7 6 pnnx.Input input_0 0 1 input prim::Constant op_0 0 1 dtype value=%dtype prim::Constant op_1 0 1 non_blocking value=* prim::Constant op_2 0 1 copy value=%copy prim::Constant op_3 0 1 memory_format value=%memory_format aten::to op_4 5 1 input dtype non_blocking copy memory_format out pnnx.Output output 1 0 out )PNNXIR"; } const char* type_str() const { return "Tensor.to"; } void write(Operator* op, const std::map& captured_params) const { if (captured_params.at("dtype").type == 0) { op->params["dtype"] = Parameter(); } else // if (captured_params.at("dtype").type == 2) { if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; } op->params["copy"] = captured_params.at("copy"); if (captured_params.at("memory_format").type == 2) { if (captured_params.at("memory_format").i == 0) op->params["memory_format"] = "torch.contiguous_format"; if (captured_params.at("memory_format").i == 1) op->params["memory_format"] = "torch.preserve_format"; if (captured_params.at("memory_format").i == 2) op->params["memory_format"] = "torch.channels_last"; } } }; class Tensor_to_1 : public Tensor_to { public: const char* match_pattern_graph() const { return R"PNNXIR(7767517 8 7 pnnx.Input input_0 0 1 input prim::Constant op_0 0 1 device value=* prim::Constant op_1 0 1 dtype value=%dtype prim::Constant op_2 0 1 non_blocking value=* prim::Constant op_3 0 1 copy value=%copy prim::Constant op_4 0 1 memory_format value=%memory_format aten::to op_5 6 1 input device dtype non_blocking copy memory_format out pnnx.Output output 1 0 out )PNNXIR"; } }; class Tensor_to_2 : public Tensor_to { public: const char* match_pattern_graph() const { return R"PNNXIR(7767517 10 9 pnnx.Input input_0 0 1 input prim::Constant op_0 0 1 dtype value=%dtype prim::Constant op_1 0 1 layout value=* prim::Constant op_2 0 1 device value=* prim::Constant op_3 0 1 pin_memory value=* prim::Constant op_4 0 1 non_blocking value=* prim::Constant op_5 0 1 copy value=%copy prim::Constant op_6 0 1 memory_format value=%memory_format aten::to op_7 8 1 input dtype layout device pin_memory non_blocking copy memory_format out pnnx.Output output 1 0 out )PNNXIR"; } }; REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to, 20) REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_1, 20) REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_2, 20) } // namespace pnnx