diff --git a/.ci/pnnx.yml b/.ci/pnnx.yml index e3760d111..596e753de 100644 --- a/.ci/pnnx.yml +++ b/.ci/pnnx.yml @@ -58,7 +58,7 @@ jobs: - name: install-deps run: | apt-get update - apt-get install -y python3-pip libjpeg-dev libpng-dev + apt-get install -y python3-pip libjpeg-dev libpng-dev libprotobuf-dev protobuf-compiler python3 -m pip install --upgrade pip pip3 uninstall -y setuptools pip3 install -U pytest setuptools wheel twine distribute requests diff --git a/tools/pnnx/README.md b/tools/pnnx/README.md index eac022941..7ed3d887a 100644 --- a/tools/pnnx/README.md +++ b/tools/pnnx/README.md @@ -62,7 +62,7 @@ mod.save("resnet18.pt") pnnx resnet18.pt inputshape=[1,3,224,224] ``` -Normally, you will get six files +Normally, you will get seven files ```resnet18.pnnx.param``` PNNX graph definition @@ -70,6 +70,8 @@ Normally, you will get six files ```resnet18_pnnx.py``` PyTorch script for inference, the python code for model construction and weight initialization +```resnet18.pnnx.onnx``` PNNX model in onnx format + ```resnet18.ncnn.param``` ncnn graph definition ```resnet18.ncnn.bin``` ncnn model weight @@ -87,6 +89,7 @@ Usage: pnnx [model.pt] [(key=value)...] pnnxparam=model.pnnx.param pnnxbin=model.pnnx.bin pnnxpy=model_pnnx.py + pnnxonnx=model.pnnx.onnx ncnnparam=model.ncnn.param ncnnbin=model.ncnn.bin ncnnpy=model_ncnn.py @@ -108,6 +111,8 @@ Parameters: `pnnxpy` (default="*_pnnx.py"): PyTorch script for inference, including model construction and weight initialization code +`pnnxonnx` (default="*.pnnx.onnx"): PNNX model in onnx format + `ncnnparam` (default="*.ncnn.param"): ncnn graph definition `ncnnbin` (default="*.ncnn.bin"): ncnn model weight diff --git a/tools/pnnx/src/CMakeLists.txt b/tools/pnnx/src/CMakeLists.txt index f29437f54..c2004c8da 100644 --- a/tools/pnnx/src/CMakeLists.txt +++ b/tools/pnnx/src/CMakeLists.txt @@ -497,6 +497,27 @@ set(pnnx_pass_ncnn_SRCS pass_ncnn/torchvision_DeformConv2d.cpp ) +find_package(Protobuf) +if(PROTOBUF_FOUND) + protobuf_generate_cpp(ONNX_PROTO_SRCS ONNX_PROTO_HDRS onnx.proto) + + add_library(pnnx2onnx STATIC + save_onnx.cpp + save_onnx_cxxabi_bridge.cpp + ${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS} + ) + + target_include_directories(pnnx2onnx PRIVATE ${PROTOBUF_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR}) + target_link_libraries(pnnx2onnx PRIVATE ${PROTOBUF_LIBRARIES}) + + # libtorch is usually compiled with old cxx11 abi + set_source_files_properties(save_onnx_cxxabi_bridge.cpp PROPERTIES COMPILE_FLAGS "${TORCH_CXX_FLAGS}") + + message(STATUS "Building with onnx-zero") +else() + message(STATUS "Building without onnx-zero") +endif() + set(pnnx_SRCS main.cpp ir.cpp @@ -510,8 +531,6 @@ set(pnnx_SRCS pass_level4.cpp pass_level5.cpp - pass_ncnn.cpp - ${pnnx_pass_level0_SRCS} ${pnnx_pass_level1_SRCS} ${pnnx_pass_level2_SRCS} @@ -519,6 +538,8 @@ set(pnnx_SRCS ${pnnx_pass_level4_SRCS} ${pnnx_pass_level5_SRCS} + pass_ncnn.cpp + save_ncnn.cpp ${pnnx_pass_ncnn_SRCS} ) @@ -528,6 +549,8 @@ endif() add_executable(pnnx ${pnnx_SRCS}) +target_compile_definitions(pnnx PRIVATE BUILD_PNNX) + if(PNNX_COVERAGE) target_compile_options(pnnx PUBLIC -coverage -fprofile-arcs -ftest-coverage) target_link_libraries(pnnx PUBLIC -coverage -lgcov) @@ -537,6 +560,11 @@ if(WIN32) target_compile_definitions(pnnx PUBLIC NOMINMAX) endif() +if(PROTOBUF_FOUND) + target_compile_definitions(pnnx PRIVATE BUILD_PNNX2ONNX) + target_link_libraries(pnnx PRIVATE pnnx2onnx) +endif() + if(TorchVision_FOUND) target_link_libraries(pnnx PRIVATE TorchVision::TorchVision) endif() diff --git a/tools/pnnx/src/ir.cpp b/tools/pnnx/src/ir.cpp index 705b1e1c6..0f44763af 100644 --- a/tools/pnnx/src/ir.cpp +++ b/tools/pnnx/src/ir.cpp @@ -16,13 +16,16 @@ #include #include +#include #include #include #include #include #include +#if BUILD_PNNX #include +#endif #include "storezip.h" @@ -130,6 +133,7 @@ static int string_to_type(const char* s) return 0; // null } +#if BUILD_PNNX int get_at_tensor_type(const at::ScalarType& st) { if (st == c10::ScalarType::Float) return 1; @@ -295,6 +299,7 @@ Parameter::Parameter(const torch::jit::Value* value) : Parameter(value->node()) { } +#endif // BUILD_PNNX bool operator==(const Parameter& lhs, const Parameter& rhs) { @@ -328,6 +333,7 @@ bool operator==(const Parameter& lhs, const Parameter& rhs) return false; } +#if BUILD_PNNX Attribute::Attribute(const at::Tensor& t) { type = get_at_tensor_type(t.scalar_type()); @@ -384,6 +390,7 @@ Attribute::Attribute(const at::Tensor& t) memcpy((void*)data.data(), (const void*)t.cpu().contiguous().data_ptr(), data.size()); } } +#endif // BUILD_PNNX Attribute::Attribute(const std::initializer_list& _shape, const std::vector& t) { @@ -2289,314 +2296,6 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath) return 0; } -static bool string_is_positive_integer(const std::string& t) -{ - for (size_t i = 0; i < t.size(); i++) - { - if (t[i] < '0' || t[i] > '9') - return false; - } - - return true; -} - -int Graph::ncnn(const std::string& parampath, const std::string& binpath, const std::string& pypath) -{ - FILE* paramfp = fopen(parampath.c_str(), "wb"); - if (!paramfp) - { - fprintf(stderr, "fopen %s failed\n", parampath.c_str()); - return -1; - } - - FILE* binfp = fopen(binpath.c_str(), "wb"); - if (!binfp) - { - fprintf(stderr, "fopen %s failed\n", binpath.c_str()); - fclose(paramfp); - return -1; - } - - // magic - fprintf(paramfp, "7767517\n"); - - // op count and oprand count - fprintf(paramfp, "%d %d\n", (int)ops.size(), (int)operands.size()); - - for (const Operator* op : ops) - { - fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size()); - - for (const Operand* oprand : op->inputs) - { - fprintf(paramfp, " %s", oprand->name.c_str()); - } - - for (const Operand* oprand : op->outputs) - { - fprintf(paramfp, " %s", oprand->name.c_str()); - } - - for (const auto& it : op->params) - { - const Parameter& param = it.second; - - if (!string_is_positive_integer(it.first)) - { - fprintf(stderr, "ignore %s %s param %s=", op->type.c_str(), op->name.c_str(), it.first.c_str()); - - if (param.type == 0) - { - fprintf(stderr, "None"); - } - if (param.type == 1) - { - if (param.b) - fprintf(stderr, "True"); - else - fprintf(stderr, "False"); - } - if (param.type == 2) - { - fprintf(stderr, "%d", param.i); - } - if (param.type == 3) - { - fprintf(stderr, "%e", param.f); - } - if (param.type == 4) - { - fprintf(stderr, "%s", param.s.c_str()); - } - if (param.type == 5) - { - fprintf(stderr, "("); - for (size_t i = 0; i < param.ai.size(); i++) - { - fprintf(stderr, "%d", param.ai[i]); - if (i + 1 != param.ai.size()) - fprintf(stderr, ","); - } - fprintf(stderr, ")"); - } - if (param.type == 6) - { - fprintf(stderr, "("); - for (size_t i = 0; i < param.af.size(); i++) - { - fprintf(stderr, "%e", param.af[i]); - if (i + 1 != param.af.size()) - fprintf(stderr, ","); - } - fprintf(stderr, ")"); - } - if (param.type == 7) - { - fprintf(stderr, "("); - for (size_t i = 0; i < param.as.size(); i++) - { - fprintf(stderr, "%s", param.as[i].c_str()); - if (i + 1 != param.as.size()) - fprintf(stderr, ","); - } - fprintf(stderr, ")"); - } - fprintf(stderr, "\n"); - - continue; - } - - const int idkey = std::stoi(it.first); - if (param.type == 2) - { - fprintf(paramfp, " %d=%d", idkey, param.i); - } - if (param.type == 3) - { - fprintf(paramfp, " %d=%e", idkey, param.f); - } - if (param.type == 5) - { - const int array_size = (int)param.ai.size(); - fprintf(paramfp, " %d=%d", -23300 - idkey, array_size); - for (size_t i = 0; i < param.ai.size(); i++) - { - fprintf(paramfp, ",%d", param.ai[i]); - } - } - if (param.type == 6) - { - const int array_size = (int)param.af.size(); - fprintf(paramfp, " %d=%d", -23300 - idkey, array_size); - for (size_t i = 0; i < param.af.size(); i++) - { - fprintf(paramfp, ",%e", param.af[i]); - } - } - } - - for (const auto& it : op->attrs) - { - // fprintf(paramfp, " @%s=", it.first.c_str()); - - const Attribute& attr = it.second; - - fwrite(attr.data.data(), attr.data.size(), 1, binfp); - } - - // if (op->inputnames.size() == op->inputs.size()) - // { - // for (size_t i = 0; i < op->inputs.size(); i++) - // { - // const Operand* oprand = op->inputs[i]; - // fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str()); - // } - // } - - // for (const Operand* oprand : op->outputs) - // { - // if (oprand->params.find("__batch_index") == oprand->params.end()) - // continue; - // - // const int batch_index = oprand->params.at("__batch_index").i; - // - // fprintf(paramfp, " #%s=%d", oprand->name.c_str(), batch_index); - // } - - // for (const Operand* oprand : op->outputs) - // { - // if (oprand->shape.empty()) - // continue; - // - // fprintf(paramfp, " #%s=", oprand->name.c_str()); - // - // fprintf(paramfp, "("); - // for (int64_t i = 0; i < oprand->shape.size() - 1; i++) - // { - // fprintf(paramfp, "%d,", oprand->shape[i]); - // } - // if (oprand->shape.size() > 0) - // fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]); - // fprintf(paramfp, ")"); - // - // fprintf(paramfp, type_to_string(oprand->type)); - // } - - fprintf(paramfp, "\n"); - } - - fclose(paramfp); - fclose(binfp); - - FILE* pyfp = fopen(pypath.c_str(), "wb"); - if (!pyfp) - { - fprintf(stderr, "fopen %s failed\n", pypath.c_str()); - return -1; - } - - fprintf(pyfp, "import numpy as np\n"); - fprintf(pyfp, "import ncnn\n"); - fprintf(pyfp, "import torch\n"); - - fprintf(pyfp, "\n"); - - // test inference - { - fprintf(pyfp, "def test_inference():\n"); - fprintf(pyfp, " torch.manual_seed(0)\n"); - - for (int input_index = 0;; input_index++) - { - std::string input_name = std::string("in") + std::to_string(input_index); - const Operand* r = get_operand(input_name); - if (!r) - break; - - if (type_is_integer(r->type)) - { - fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); - for (size_t i = 0; i < r->shape.size(); i++) - { - fprintf(pyfp, "%d", r->shape[i]); - if (i + 1 != r->shape.size() || r->shape.size() == 1) - fprintf(pyfp, ", "); - } - fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); - } - else - { - fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); - for (size_t i = 0; i < r->shape.size(); i++) - { - fprintf(pyfp, "%d, ", r->shape[i]); - } - fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); - } - } - - fprintf(pyfp, " out = []\n"); - fprintf(pyfp, "\n"); - - fprintf(pyfp, " with ncnn.Net() as net:\n"); - fprintf(pyfp, " net.load_param(\"%s\")\n", parampath.c_str()); - fprintf(pyfp, " net.load_model(\"%s\")\n", binpath.c_str()); - fprintf(pyfp, "\n"); - fprintf(pyfp, " with net.create_extractor() as ex:\n"); - - for (int input_index = 0;; input_index++) - { - std::string input_name = std::string("in") + std::to_string(input_index); - const Operand* r = get_operand(input_name); - if (!r) - break; - - const int batch_index = r->params.at("__batch_index").i; - if (batch_index != 233) - { - fprintf(pyfp, " ex.input(\"%s\", ncnn.Mat(%s.squeeze(%d).numpy()).clone())\n", input_name.c_str(), input_name.c_str(), batch_index); - } - else - { - fprintf(pyfp, " ex.input(\"%s\", ncnn.Mat(%s.numpy()).clone())\n", input_name.c_str(), input_name.c_str()); - } - } - - fprintf(pyfp, "\n"); - - for (int output_index = 0;; output_index++) - { - std::string output_name = std::string("out") + std::to_string(output_index); - const Operand* r = get_operand(output_name); - if (!r) - break; - - fprintf(pyfp, " _, %s = ex.extract(\"%s\")\n", output_name.c_str(), output_name.c_str()); - - const int batch_index = r->params.at("__batch_index").i; - if (batch_index != 233) - { - fprintf(pyfp, " out.append(torch.from_numpy(np.array(%s)).unsqueeze(%d))\n", output_name.c_str(), batch_index); - } - else - { - fprintf(pyfp, " out.append(torch.from_numpy(np.array(%s)))\n", output_name.c_str()); - } - } - - fprintf(pyfp, "\n"); - - fprintf(pyfp, " if len(out) == 1:\n"); - fprintf(pyfp, " return out[0]\n"); - fprintf(pyfp, " else:\n"); - fprintf(pyfp, " return tuple(out)\n"); - } - - fclose(pyfp); - - return 0; -} - int Graph::parse(const std::string& param) { std::istringstream is(param); @@ -2731,6 +2430,7 @@ Operator* Graph::new_operator_after(const std::string& type, const std::string& return op; } +#if BUILD_PNNX Operand* Graph::new_operand(const torch::jit::Value* v) { Operand* r = new Operand; @@ -2757,6 +2457,7 @@ Operand* Graph::new_operand(const torch::jit::Value* v) operands.push_back(r); return r; } +#endif // BUILD_PNNX Operand* Graph::new_operand(const std::string& name) { @@ -2777,4 +2478,15 @@ Operand* Graph::get_operand(const std::string& name) return 0; } +const Operand* Graph::get_operand(const std::string& name) const +{ + for (const Operand* r : operands) + { + if (r->name == name) + return r; + } + + return 0; +} + } // namespace pnnx diff --git a/tools/pnnx/src/ir.h b/tools/pnnx/src/ir.h index 06fe09c14..b59c9b2bb 100644 --- a/tools/pnnx/src/ir.h +++ b/tools/pnnx/src/ir.h @@ -20,6 +20,7 @@ #include #include +#if BUILD_PNNX namespace torch { namespace jit { struct Value; @@ -29,6 +30,7 @@ struct Node; namespace at { class Tensor; } +#endif // BUILD_PNNX namespace pnnx { @@ -114,8 +116,10 @@ public: { } +#if BUILD_PNNX Parameter(const torch::jit::Node* value_node); Parameter(const torch::jit::Value* value); +#endif // BUILD_PNNX static Parameter parse_from_string(const std::string& value); @@ -126,9 +130,11 @@ public: bool b; int i; float f; - std::string s; std::vector ai; std::vector af; + + // keep std::string typed member the last for cross cxxabi compatibility + std::string s; std::vector as; }; @@ -142,7 +148,9 @@ public: { } +#if BUILD_PNNX Attribute(const at::Tensor& t); +#endif // BUILD_PNNX Attribute(const std::initializer_list& shape, const std::vector& t); @@ -164,8 +172,6 @@ class Operand public: void remove_consumer(const Operator* c); - std::string name; - Operator* producer; std::vector consumers; @@ -173,6 +179,9 @@ public: int type; std::vector shape; + // keep std::string typed member the last for cross cxxabi compatibility + std::string name; + std::map params; private: @@ -185,12 +194,13 @@ private: class Operator { public: - std::string type; - std::string name; - std::vector inputs; std::vector outputs; + // keep std::string typed member the last for cross cxxabi compatibility + std::string type; + std::string name; + std::vector inputnames; std::map params; std::map attrs; @@ -213,8 +223,6 @@ public: int python(const std::string& pypath, const std::string& binpath); - int ncnn(const std::string& parampath, const std::string& binpath, const std::string& pypath); - int parse(const std::string& param); Operator* new_operator(const std::string& type, const std::string& name); @@ -223,11 +231,14 @@ public: Operator* new_operator_after(const std::string& type, const std::string& name, const Operator* cur); +#if BUILD_PNNX Operand* new_operand(const torch::jit::Value* v); +#endif Operand* new_operand(const std::string& name); Operand* get_operand(const std::string& name); + const Operand* get_operand(const std::string& name) const; std::vector ops; std::vector operands; diff --git a/tools/pnnx/src/main.cpp b/tools/pnnx/src/main.cpp index 57290fa27..dbe07fd48 100644 --- a/tools/pnnx/src/main.cpp +++ b/tools/pnnx/src/main.cpp @@ -39,6 +39,11 @@ #include "pass_level5.h" #include "pass_ncnn.h" +#include "save_ncnn.h" + +#if BUILD_PNNX2ONNX +#include "save_onnx.h" +#endif static std::string get_basename(const std::string& path) { @@ -159,6 +164,7 @@ static void show_usage() fprintf(stderr, " pnnxparam=model.pnnx.param\n"); fprintf(stderr, " pnnxbin=model.pnnx.bin\n"); fprintf(stderr, " pnnxpy=model_pnnx.py\n"); + fprintf(stderr, " pnnxonnx=model.pnnx.onnx\n"); fprintf(stderr, " ncnnparam=model.ncnn.param\n"); fprintf(stderr, " ncnnbin=model.ncnn.bin\n"); fprintf(stderr, " ncnnpy=model_ncnn.py\n"); @@ -200,6 +206,7 @@ int main(int argc, char** argv) std::string pnnxparampath = ptbase + ".pnnx.param"; std::string pnnxbinpath = ptbase + ".pnnx.bin"; std::string pnnxpypath = ptbase + "_pnnx.py"; + std::string pnnxonnxpath = ptbase + ".pnnx.onnx"; std::string ncnnparampath = ptbase + ".ncnn.param"; std::string ncnnbinpath = ptbase + ".ncnn.bin"; std::string ncnnpypath = ptbase + "_ncnn.py"; @@ -235,6 +242,8 @@ int main(int argc, char** argv) pnnxbinpath = std::string(value); if (strcmp(key, "pnnxpy") == 0) pnnxpypath = std::string(value); + if (strcmp(key, "pnnxonnx") == 0) + pnnxonnxpath = std::string(value); if (strcmp(key, "ncnnparam") == 0) ncnnparampath = std::string(value); if (strcmp(key, "ncnnbin") == 0) @@ -260,6 +269,7 @@ int main(int argc, char** argv) fprintf(stderr, "pnnxparam = %s\n", pnnxparampath.c_str()); fprintf(stderr, "pnnxbin = %s\n", pnnxbinpath.c_str()); fprintf(stderr, "pnnxpy = %s\n", pnnxpypath.c_str()); + fprintf(stderr, "pnnxonnx = %s\n", pnnxonnxpath.c_str()); fprintf(stderr, "ncnnparam = %s\n", ncnnparampath.c_str()); fprintf(stderr, "ncnnbin = %s\n", ncnnbinpath.c_str()); fprintf(stderr, "ncnnpy = %s\n", ncnnpypath.c_str()); @@ -400,13 +410,19 @@ int main(int argc, char** argv) pnnx_graph.python(pnnxpypath, pnnxbinpath); +#if BUILD_PNNX2ONNX + pnnx::save_onnx(pnnx_graph, pnnxonnxpath.c_str()); +#else + fprintf(stderr, "pnnx build without onnx-zero support, skip saving onnx\n"); +#endif + // if (optlevel >= 2) { fprintf(stderr, "############# pass_ncnn\n"); pnnx::pass_ncnn(pnnx_graph); - pnnx_graph.ncnn(ncnnparampath, ncnnbinpath, ncnnpypath); + pnnx::save_ncnn(pnnx_graph, ncnnparampath, ncnnbinpath, ncnnpypath); } // pnnx::Graph pnnx_graph2; diff --git a/tools/pnnx/src/onnx.proto b/tools/pnnx/src/onnx.proto new file mode 100644 index 000000000..461bd0b78 --- /dev/null +++ b/tools/pnnx/src/onnx.proto @@ -0,0 +1,505 @@ +// +// WARNING: This file is automatically generated! Please edit onnx.in.proto. +// + + +// Copyright (c) ONNX Project Contributors. +// Licensed under the MIT license. + +syntax = "proto2"; + +package onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// +// We are still in the very early stage of defining ONNX. The current +// version of ONNX is a starting point. While we are actively working +// towards a complete spec, we would like to get the community involved +// by sharing our working version of ONNX. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. + // For the IR, we are using simple numbers starting with with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; + + // IR_VERSION 2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x0000000000000002; + + // IR VERSION 3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION = 0x0000000000000005; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + } + + // The name field MUST be present for this version of the IR. + optional string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + optional string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + optional string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field hueristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accommodate proto3 implementations. + optional AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + optional float f = 2; // float + optional int64 i = 3; // int + optional bytes s = 4; // UTF-8 string + optional TensorProto t = 5; // tensor value + optional GraphProto g = 6; // graph + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + optional string name = 1; // namespace Value + // This field MUST be present in this version of the IR. + optional TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + optional string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + optional string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + optional string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + optional string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + optional string doc_string = 6; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + optional int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + optional string domain = 4; + + // The version of the graph encoded. See Version enum below. + optional int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + optional string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + optional GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + optional string key = 1; + optional string value= 2; +}; + +message TensorAnnotation { + optional string tensor_name = 1; + // pairs to annotate tensor specified by above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + optional string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each TensorProto entry must have a distinct name (within the list) that + // MAY also appear in the input list. + repeated TensorProto initializer = 5; + + // A human-readable documentation for this graph. Markdown is allowed. + optional string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. + FLOAT16 = 10; + + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + // This field MUST have a valid TensorProto.DataType value + optional int32 data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + optional int64 begin = 1; + optional int64 end = 2; + } + optional Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + optional string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + optional string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + optional bytes raw_data = 9; + + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + optional DataLocation data_location = 14; + + // For double + // Complex128 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + optional string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + optional int32 elem_type = 1; + optional TensorShapeProto shape = 2; + } + + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + optional string denotation = 6; +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + optional string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + optional int64 version = 2; +} diff --git a/tools/pnnx/src/save_ncnn.cpp b/tools/pnnx/src/save_ncnn.cpp new file mode 100644 index 000000000..b0710e9db --- /dev/null +++ b/tools/pnnx/src/save_ncnn.cpp @@ -0,0 +1,361 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "save_ncnn.h" + +namespace pnnx { + +static bool type_is_integer(int type) +{ + if (type == 1) return false; + if (type == 2) return false; + if (type == 3) return false; + if (type == 4) return true; + if (type == 5) return true; + if (type == 6) return true; + if (type == 7) return true; + if (type == 8) return true; + if (type == 9) return true; + if (type == 10) return false; + if (type == 11) return false; + if (type == 12) return false; + return false; +} + +static const char* type_to_dtype_string(int type) +{ + if (type == 1) return "torch.float"; + if (type == 2) return "torch.double"; + if (type == 3) return "torch.half"; + if (type == 4) return "torch.int"; + if (type == 5) return "torch.long"; + if (type == 6) return "torch.short"; + if (type == 7) return "torch.int8"; + if (type == 8) return "torch.uint8"; + if (type == 9) return "torch.bool"; + if (type == 10) return "torch.complex64"; + if (type == 11) return "torch.complex128"; + if (type == 12) return "torch.complex32"; + return "null"; +} + +static bool string_is_positive_integer(const std::string& t) +{ + for (size_t i = 0; i < t.size(); i++) + { + if (t[i] < '0' || t[i] > '9') + return false; + } + + return true; +} + +int save_ncnn(const Graph& g, const std::string& parampath, const std::string& binpath, const std::string& pypath) +{ + FILE* paramfp = fopen(parampath.c_str(), "wb"); + if (!paramfp) + { + fprintf(stderr, "fopen %s failed\n", parampath.c_str()); + return -1; + } + + FILE* binfp = fopen(binpath.c_str(), "wb"); + if (!binfp) + { + fprintf(stderr, "fopen %s failed\n", binpath.c_str()); + fclose(paramfp); + return -1; + } + + // magic + fprintf(paramfp, "7767517\n"); + + // op count and oprand count + fprintf(paramfp, "%d %d\n", (int)g.ops.size(), (int)g.operands.size()); + + for (const Operator* op : g.ops) + { + fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size()); + + for (const Operand* oprand : op->inputs) + { + fprintf(paramfp, " %s", oprand->name.c_str()); + } + + for (const Operand* oprand : op->outputs) + { + fprintf(paramfp, " %s", oprand->name.c_str()); + } + + for (const auto& it : op->params) + { + const Parameter& param = it.second; + + if (!string_is_positive_integer(it.first)) + { + fprintf(stderr, "ignore %s %s param %s=", op->type.c_str(), op->name.c_str(), it.first.c_str()); + + if (param.type == 0) + { + fprintf(stderr, "None"); + } + if (param.type == 1) + { + if (param.b) + fprintf(stderr, "True"); + else + fprintf(stderr, "False"); + } + if (param.type == 2) + { + fprintf(stderr, "%d", param.i); + } + if (param.type == 3) + { + fprintf(stderr, "%e", param.f); + } + if (param.type == 4) + { + fprintf(stderr, "%s", param.s.c_str()); + } + if (param.type == 5) + { + fprintf(stderr, "("); + for (size_t i = 0; i < param.ai.size(); i++) + { + fprintf(stderr, "%d", param.ai[i]); + if (i + 1 != param.ai.size()) + fprintf(stderr, ","); + } + fprintf(stderr, ")"); + } + if (param.type == 6) + { + fprintf(stderr, "("); + for (size_t i = 0; i < param.af.size(); i++) + { + fprintf(stderr, "%e", param.af[i]); + if (i + 1 != param.af.size()) + fprintf(stderr, ","); + } + fprintf(stderr, ")"); + } + if (param.type == 7) + { + fprintf(stderr, "("); + for (size_t i = 0; i < param.as.size(); i++) + { + fprintf(stderr, "%s", param.as[i].c_str()); + if (i + 1 != param.as.size()) + fprintf(stderr, ","); + } + fprintf(stderr, ")"); + } + fprintf(stderr, "\n"); + + continue; + } + + const int idkey = std::stoi(it.first); + if (param.type == 2) + { + fprintf(paramfp, " %d=%d", idkey, param.i); + } + if (param.type == 3) + { + fprintf(paramfp, " %d=%e", idkey, param.f); + } + if (param.type == 5) + { + const int array_size = (int)param.ai.size(); + fprintf(paramfp, " %d=%d", -23300 - idkey, array_size); + for (size_t i = 0; i < param.ai.size(); i++) + { + fprintf(paramfp, ",%d", param.ai[i]); + } + } + if (param.type == 6) + { + const int array_size = (int)param.af.size(); + fprintf(paramfp, " %d=%d", -23300 - idkey, array_size); + for (size_t i = 0; i < param.af.size(); i++) + { + fprintf(paramfp, ",%e", param.af[i]); + } + } + } + + for (const auto& it : op->attrs) + { + // fprintf(paramfp, " @%s=", it.first.c_str()); + + const Attribute& attr = it.second; + + fwrite(attr.data.data(), attr.data.size(), 1, binfp); + } + + // if (op->inputnames.size() == op->inputs.size()) + // { + // for (size_t i = 0; i < op->inputs.size(); i++) + // { + // const Operand* oprand = op->inputs[i]; + // fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str()); + // } + // } + + // for (const Operand* oprand : op->outputs) + // { + // if (oprand->params.find("__batch_index") == oprand->params.end()) + // continue; + // + // const int batch_index = oprand->params.at("__batch_index").i; + // + // fprintf(paramfp, " #%s=%d", oprand->name.c_str(), batch_index); + // } + + // for (const Operand* oprand : op->outputs) + // { + // if (oprand->shape.empty()) + // continue; + // + // fprintf(paramfp, " #%s=", oprand->name.c_str()); + // + // fprintf(paramfp, "("); + // for (int64_t i = 0; i < oprand->shape.size() - 1; i++) + // { + // fprintf(paramfp, "%d,", oprand->shape[i]); + // } + // if (oprand->shape.size() > 0) + // fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]); + // fprintf(paramfp, ")"); + // + // fprintf(paramfp, type_to_string(oprand->type)); + // } + + fprintf(paramfp, "\n"); + } + + fclose(paramfp); + fclose(binfp); + + FILE* pyfp = fopen(pypath.c_str(), "wb"); + if (!pyfp) + { + fprintf(stderr, "fopen %s failed\n", pypath.c_str()); + return -1; + } + + fprintf(pyfp, "import numpy as np\n"); + fprintf(pyfp, "import ncnn\n"); + fprintf(pyfp, "import torch\n"); + + fprintf(pyfp, "\n"); + + // test inference + { + fprintf(pyfp, "def test_inference():\n"); + fprintf(pyfp, " torch.manual_seed(0)\n"); + + for (int input_index = 0;; input_index++) + { + std::string input_name = std::string("in") + std::to_string(input_index); + const Operand* r = g.get_operand(input_name); + if (!r) + break; + + if (type_is_integer(r->type)) + { + fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str()); + for (size_t i = 0; i < r->shape.size(); i++) + { + fprintf(pyfp, "%d", r->shape[i]); + if (i + 1 != r->shape.size() || r->shape.size() == 1) + fprintf(pyfp, ", "); + } + fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type)); + } + else + { + fprintf(pyfp, " %s = torch.rand(", input_name.c_str()); + for (size_t i = 0; i < r->shape.size(); i++) + { + fprintf(pyfp, "%d, ", r->shape[i]); + } + fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type)); + } + } + + fprintf(pyfp, " out = []\n"); + fprintf(pyfp, "\n"); + + fprintf(pyfp, " with ncnn.Net() as net:\n"); + fprintf(pyfp, " net.load_param(\"%s\")\n", parampath.c_str()); + fprintf(pyfp, " net.load_model(\"%s\")\n", binpath.c_str()); + fprintf(pyfp, "\n"); + fprintf(pyfp, " with net.create_extractor() as ex:\n"); + + for (int input_index = 0;; input_index++) + { + std::string input_name = std::string("in") + std::to_string(input_index); + const Operand* r = g.get_operand(input_name); + if (!r) + break; + + const int batch_index = r->params.at("__batch_index").i; + if (batch_index != 233) + { + fprintf(pyfp, " ex.input(\"%s\", ncnn.Mat(%s.squeeze(%d).numpy()).clone())\n", input_name.c_str(), input_name.c_str(), batch_index); + } + else + { + fprintf(pyfp, " ex.input(\"%s\", ncnn.Mat(%s.numpy()).clone())\n", input_name.c_str(), input_name.c_str()); + } + } + + fprintf(pyfp, "\n"); + + for (int output_index = 0;; output_index++) + { + std::string output_name = std::string("out") + std::to_string(output_index); + const Operand* r = g.get_operand(output_name); + if (!r) + break; + + fprintf(pyfp, " _, %s = ex.extract(\"%s\")\n", output_name.c_str(), output_name.c_str()); + + const int batch_index = r->params.at("__batch_index").i; + if (batch_index != 233) + { + fprintf(pyfp, " out.append(torch.from_numpy(np.array(%s)).unsqueeze(%d))\n", output_name.c_str(), batch_index); + } + else + { + fprintf(pyfp, " out.append(torch.from_numpy(np.array(%s)))\n", output_name.c_str()); + } + } + + fprintf(pyfp, "\n"); + + fprintf(pyfp, " if len(out) == 1:\n"); + fprintf(pyfp, " return out[0]\n"); + fprintf(pyfp, " else:\n"); + fprintf(pyfp, " return tuple(out)\n"); + } + + fclose(pyfp); + + return 0; +} + +} // namespace pnnx diff --git a/tools/pnnx/src/save_ncnn.h b/tools/pnnx/src/save_ncnn.h new file mode 100644 index 000000000..c49f506d3 --- /dev/null +++ b/tools/pnnx/src/save_ncnn.h @@ -0,0 +1,26 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef PNNX_SAVE_NCNN_H +#define PNNX_SAVE_NCNN_H + +#include "ir.h" + +namespace pnnx { + +int save_ncnn(const Graph& g, const std::string& parampath, const std::string& binpath, const std::string& pypath); + +} // namespace pnnx + +#endif // PNNX_SAVE_NCNN_H diff --git a/tools/pnnx/src/save_onnx.cpp b/tools/pnnx/src/save_onnx.cpp new file mode 100644 index 000000000..86c64b904 --- /dev/null +++ b/tools/pnnx/src/save_onnx.cpp @@ -0,0 +1,268 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "save_onnx.h" + +#include "onnx.pb.h" + +#include +#include +#include + +namespace pnnx { + +// from cxxabi bridge +extern const char* get_operand_name(const Operand* x); +extern const char* get_operator_type(const Operator* op); +extern const char* get_operator_name(const Operator* op); +extern std::vector get_operator_params_keys(const Operator* op); +extern std::vector get_operator_attrs_keys(const Operator* op); +extern const Parameter& get_operator_param(const Operator* op, const char* key); +extern const Attribute& get_operator_attr(const Operator* op, const char* key); +extern const char* get_param_s(const Parameter& p); +extern std::vector get_param_as(const Parameter& p); + +int save_onnx(const Graph& g, const char* onnxpath) +{ + onnx::ModelProto model; + + onnx::GraphProto* gp = model.mutable_graph(); + + for (const Operand* x : g.operands) + { + onnx::ValueInfoProto* vip = gp->add_value_info(); + + vip->set_name(get_operand_name(x)); + + onnx::TypeProto* tp = vip->mutable_type(); + + onnx::TypeProto_Tensor* tpt = tp->mutable_tensor_type(); + + switch (x->type) + { + case 1: // f32 + tpt->set_elem_type(1); + break; + case 2: // f64 + tpt->set_elem_type(11); + break; + case 3: // f16 + tpt->set_elem_type(10); + break; + case 4: // i32 + tpt->set_elem_type(6); + break; + case 5: // i64 + tpt->set_elem_type(7); + break; + case 6: // i16 + tpt->set_elem_type(5); + break; + case 7: // i8 + tpt->set_elem_type(3); + break; + case 8: // u8 + tpt->set_elem_type(2); + break; + case 9: // bool + tpt->set_elem_type(9); + break; + case 10: // cp64 + tpt->set_elem_type(14); + break; + case 11: // cp128 + tpt->set_elem_type(15); + break; + case 12: // cp32 + tpt->set_elem_type(0); + break; + default: // null + tpt->set_elem_type(0); + break; + } + + onnx::TensorShapeProto* tsp = tpt->mutable_shape(); + + for (auto s : x->shape) + { + onnx::TensorShapeProto_Dimension* tspd = tsp->add_dim(); + + tspd->set_dim_value(s); + } + } + + for (const Operator* op : g.ops) + { + onnx::NodeProto* np = gp->add_node(); + + np->set_op_type(get_operator_type(op)); + np->set_name(get_operator_name(op)); + + for (const Operand* oprand : op->inputs) + { + np->add_input(get_operand_name(oprand)); + } + + for (const Operand* oprand : op->outputs) + { + np->add_output(get_operand_name(oprand)); + } + + std::vector params_keys = get_operator_params_keys(op); + + // for (const auto& it : op->params) + for (const char* param_name : params_keys) + { + // const Parameter& param = it.second; + const Parameter& param = get_operator_param(op, param_name); + + onnx::AttributeProto* ap = np->add_attribute(); + + // ap->set_name(get_param_name(it)); + ap->set_name(param_name); + + if (param.type == 0) + { + ap->set_s("None"); + } + if (param.type == 1) + { + if (param.b) + ap->set_i(1); + else + ap->set_i(0); + } + if (param.type == 2) + { + ap->set_i(param.i); + } + if (param.type == 3) + { + ap->set_f(param.f); + } + if (param.type == 4) + { + ap->set_s(get_param_s(param)); + } + if (param.type == 5) + { + for (auto i : param.ai) + { + ap->add_ints(i); + } + } + if (param.type == 6) + { + for (auto f : param.af) + { + ap->add_floats(f); + } + } + if (param.type == 7) + { + std::vector as = get_param_as(param); + for (auto s : as) + { + ap->add_strings(s); + } + } + } + + std::vector attrs_keys = get_operator_attrs_keys(op); + + // for (const auto& it : op->attrs) + for (const char* attr_name : attrs_keys) + { + onnx::TensorProto* tp = gp->add_initializer(); + + tp->set_name(std::string(get_operator_name(op)) + "." + attr_name); + + np->add_input(std::string(get_operator_name(op)) + "." + attr_name); + + // const Attribute& attr = it.second; + const Attribute& attr = get_operator_attr(op, attr_name); + for (auto s : attr.shape) + { + tp->add_dims(s); + } + + switch (attr.type) + { + case 1: // f32 + tp->set_data_type(1); + break; + case 2: // f64 + tp->set_data_type(11); + break; + case 3: // f16 + tp->set_data_type(10); + break; + case 4: // i32 + tp->set_data_type(6); + break; + case 5: // i64 + tp->set_data_type(7); + break; + case 6: // i16 + tp->set_data_type(5); + break; + case 7: // i8 + tp->set_data_type(3); + break; + case 8: // u8 + tp->set_data_type(2); + break; + case 9: // bool + tp->set_data_type(9); + break; + case 10: // cp64 + tp->set_data_type(14); + break; + case 11: // cp128 + tp->set_data_type(15); + break; + case 12: // cp32 + tp->set_data_type(0); + break; + default: // null + tp->set_data_type(0); + break; + } + + std::string* d = tp->mutable_raw_data(); + d->resize(attr.data.size()); + memcpy((void*)d->data(), attr.data.data(), attr.data.size()); + } + + // if (op->inputnames.size() == op->inputs.size()) + // { + // for (size_t i = 0; i < op->inputs.size(); i++) + // { + // const Operand* oprand = op->inputs[i]; + // fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str()); + // } + // } + } + + std::fstream output(onnxpath, std::ios::out | std::ios::trunc | std::ios::binary); + if (!model.SerializeToOstream(&output)) + { + fprintf(stderr, "write onnx failed\n"); + return -1; + } + + return 0; +} + +} // namespace pnnx diff --git a/tools/pnnx/src/save_onnx.h b/tools/pnnx/src/save_onnx.h new file mode 100644 index 000000000..236a9911e --- /dev/null +++ b/tools/pnnx/src/save_onnx.h @@ -0,0 +1,26 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#ifndef PNNX_SAVE_ONNX_H +#define PNNX_SAVE_ONNX_H + +#include "ir.h" + +namespace pnnx { + +int save_onnx(const Graph& g, const char* onnxpath); + +} // namespace pnnx + +#endif // PNNX_SAVE_ONNX_H diff --git a/tools/pnnx/src/save_onnx_cxxabi_bridge.cpp b/tools/pnnx/src/save_onnx_cxxabi_bridge.cpp new file mode 100644 index 000000000..b74f2ab7a --- /dev/null +++ b/tools/pnnx/src/save_onnx_cxxabi_bridge.cpp @@ -0,0 +1,81 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "ir.h" + +namespace pnnx { + +const char* get_operand_name(const Operand* x) +{ + return x->name.c_str(); +} + +const char* get_operator_type(const Operator* op) +{ + return op->type.c_str(); +} + +const char* get_operator_name(const Operator* op) +{ + return op->name.c_str(); +} + +std::vector get_operator_params_keys(const Operator* op) +{ + std::vector keys; + for (const auto& it : op->params) + { + const std::string& key = it.first; + keys.push_back(key.c_str()); + } + return keys; +} + +std::vector get_operator_attrs_keys(const Operator* op) +{ + std::vector keys; + for (const auto& it : op->attrs) + { + const std::string& key = it.first; + keys.push_back(key.c_str()); + } + return keys; +} + +const Parameter& get_operator_param(const Operator* op, const char* key) +{ + return op->params.at(key); +} + +const Attribute& get_operator_attr(const Operator* op, const char* key) +{ + return op->attrs.at(key); +} + +const char* get_param_s(const Parameter& p) +{ + return p.s.c_str(); +} + +std::vector get_param_as(const Parameter& p) +{ + std::vector as; + for (const auto& s : p.as) + { + as.push_back(s.c_str()); + } + return as; +} + +} // namespace pnnx