Browse Source

pnnx save onnx zero (#4077)

tags/20221128
nihui GitHub 3 years ago
parent
commit
cb88e16fdf
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 1360 additions and 321 deletions
  1. +1
    -1
      .ci/pnnx.yml
  2. +6
    -1
      tools/pnnx/README.md
  3. +30
    -2
      tools/pnnx/src/CMakeLists.txt
  4. +20
    -308
      tools/pnnx/src/ir.cpp
  5. +19
    -8
      tools/pnnx/src/ir.h
  6. +17
    -1
      tools/pnnx/src/main.cpp
  7. +505
    -0
      tools/pnnx/src/onnx.proto
  8. +361
    -0
      tools/pnnx/src/save_ncnn.cpp
  9. +26
    -0
      tools/pnnx/src/save_ncnn.h
  10. +268
    -0
      tools/pnnx/src/save_onnx.cpp
  11. +26
    -0
      tools/pnnx/src/save_onnx.h
  12. +81
    -0
      tools/pnnx/src/save_onnx_cxxabi_bridge.cpp

+ 1
- 1
.ci/pnnx.yml View File

@@ -58,7 +58,7 @@ jobs:
- name: install-deps
run: |
apt-get update
apt-get install -y python3-pip libjpeg-dev libpng-dev
apt-get install -y python3-pip libjpeg-dev libpng-dev libprotobuf-dev protobuf-compiler
python3 -m pip install --upgrade pip
pip3 uninstall -y setuptools
pip3 install -U pytest setuptools wheel twine distribute requests


+ 6
- 1
tools/pnnx/README.md View File

@@ -62,7 +62,7 @@ mod.save("resnet18.pt")
pnnx resnet18.pt inputshape=[1,3,224,224]
```

Normally, you will get six files
Normally, you will get seven files

```resnet18.pnnx.param``` PNNX graph definition

@@ -70,6 +70,8 @@ Normally, you will get six files

```resnet18_pnnx.py``` PyTorch script for inference, the python code for model construction and weight initialization

```resnet18.pnnx.onnx``` PNNX model in onnx format

```resnet18.ncnn.param``` ncnn graph definition

```resnet18.ncnn.bin``` ncnn model weight
@@ -87,6 +89,7 @@ Usage: pnnx [model.pt] [(key=value)...]
pnnxparam=model.pnnx.param
pnnxbin=model.pnnx.bin
pnnxpy=model_pnnx.py
pnnxonnx=model.pnnx.onnx
ncnnparam=model.ncnn.param
ncnnbin=model.ncnn.bin
ncnnpy=model_ncnn.py
@@ -108,6 +111,8 @@ Parameters:

`pnnxpy` (default="*_pnnx.py"): PyTorch script for inference, including model construction and weight initialization code

`pnnxonnx` (default="*.pnnx.onnx"): PNNX model in onnx format

`ncnnparam` (default="*.ncnn.param"): ncnn graph definition

`ncnnbin` (default="*.ncnn.bin"): ncnn model weight


+ 30
- 2
tools/pnnx/src/CMakeLists.txt View File

@@ -497,6 +497,27 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/torchvision_DeformConv2d.cpp
)

find_package(Protobuf)
if(PROTOBUF_FOUND)
protobuf_generate_cpp(ONNX_PROTO_SRCS ONNX_PROTO_HDRS onnx.proto)

add_library(pnnx2onnx STATIC
save_onnx.cpp
save_onnx_cxxabi_bridge.cpp
${ONNX_PROTO_SRCS} ${ONNX_PROTO_HDRS}
)

target_include_directories(pnnx2onnx PRIVATE ${PROTOBUF_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR})
target_link_libraries(pnnx2onnx PRIVATE ${PROTOBUF_LIBRARIES})

# libtorch is usually compiled with old cxx11 abi
set_source_files_properties(save_onnx_cxxabi_bridge.cpp PROPERTIES COMPILE_FLAGS "${TORCH_CXX_FLAGS}")

message(STATUS "Building with onnx-zero")
else()
message(STATUS "Building without onnx-zero")
endif()

set(pnnx_SRCS
main.cpp
ir.cpp
@@ -510,8 +531,6 @@ set(pnnx_SRCS
pass_level4.cpp
pass_level5.cpp

pass_ncnn.cpp

${pnnx_pass_level0_SRCS}
${pnnx_pass_level1_SRCS}
${pnnx_pass_level2_SRCS}
@@ -519,6 +538,8 @@ set(pnnx_SRCS
${pnnx_pass_level4_SRCS}
${pnnx_pass_level5_SRCS}

pass_ncnn.cpp
save_ncnn.cpp
${pnnx_pass_ncnn_SRCS}
)

@@ -528,6 +549,8 @@ endif()

add_executable(pnnx ${pnnx_SRCS})

target_compile_definitions(pnnx PRIVATE BUILD_PNNX)

if(PNNX_COVERAGE)
target_compile_options(pnnx PUBLIC -coverage -fprofile-arcs -ftest-coverage)
target_link_libraries(pnnx PUBLIC -coverage -lgcov)
@@ -537,6 +560,11 @@ if(WIN32)
target_compile_definitions(pnnx PUBLIC NOMINMAX)
endif()

if(PROTOBUF_FOUND)
target_compile_definitions(pnnx PRIVATE BUILD_PNNX2ONNX)
target_link_libraries(pnnx PRIVATE pnnx2onnx)
endif()

if(TorchVision_FOUND)
target_link_libraries(pnnx PRIVATE TorchVision::TorchVision)
endif()


+ 20
- 308
tools/pnnx/src/ir.cpp View File

@@ -16,13 +16,16 @@

#include <limits.h>
#include <stdint.h>
#include <string.h>
#include <algorithm>
#include <fstream>
#include <sstream>
#include <string>
#include <stack>

#if BUILD_PNNX
#include <torch/script.h>
#endif

#include "storezip.h"

@@ -130,6 +133,7 @@ static int string_to_type(const char* s)
return 0; // null
}

#if BUILD_PNNX
int get_at_tensor_type(const at::ScalarType& st)
{
if (st == c10::ScalarType::Float) return 1;
@@ -295,6 +299,7 @@ Parameter::Parameter(const torch::jit::Value* value)
: Parameter(value->node())
{
}
#endif // BUILD_PNNX

bool operator==(const Parameter& lhs, const Parameter& rhs)
{
@@ -328,6 +333,7 @@ bool operator==(const Parameter& lhs, const Parameter& rhs)
return false;
}

#if BUILD_PNNX
Attribute::Attribute(const at::Tensor& t)
{
type = get_at_tensor_type(t.scalar_type());
@@ -384,6 +390,7 @@ Attribute::Attribute(const at::Tensor& t)
memcpy((void*)data.data(), (const void*)t.cpu().contiguous().data_ptr(), data.size());
}
}
#endif // BUILD_PNNX

Attribute::Attribute(const std::initializer_list<int>& _shape, const std::vector<float>& t)
{
@@ -2289,314 +2296,6 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)
return 0;
}

static bool string_is_positive_integer(const std::string& t)
{
for (size_t i = 0; i < t.size(); i++)
{
if (t[i] < '0' || t[i] > '9')
return false;
}

return true;
}

int Graph::ncnn(const std::string& parampath, const std::string& binpath, const std::string& pypath)
{
FILE* paramfp = fopen(parampath.c_str(), "wb");
if (!paramfp)
{
fprintf(stderr, "fopen %s failed\n", parampath.c_str());
return -1;
}

FILE* binfp = fopen(binpath.c_str(), "wb");
if (!binfp)
{
fprintf(stderr, "fopen %s failed\n", binpath.c_str());
fclose(paramfp);
return -1;
}

// magic
fprintf(paramfp, "7767517\n");

// op count and oprand count
fprintf(paramfp, "%d %d\n", (int)ops.size(), (int)operands.size());

for (const Operator* op : ops)
{
fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size());

for (const Operand* oprand : op->inputs)
{
fprintf(paramfp, " %s", oprand->name.c_str());
}

for (const Operand* oprand : op->outputs)
{
fprintf(paramfp, " %s", oprand->name.c_str());
}

for (const auto& it : op->params)
{
const Parameter& param = it.second;

if (!string_is_positive_integer(it.first))
{
fprintf(stderr, "ignore %s %s param %s=", op->type.c_str(), op->name.c_str(), it.first.c_str());

if (param.type == 0)
{
fprintf(stderr, "None");
}
if (param.type == 1)
{
if (param.b)
fprintf(stderr, "True");
else
fprintf(stderr, "False");
}
if (param.type == 2)
{
fprintf(stderr, "%d", param.i);
}
if (param.type == 3)
{
fprintf(stderr, "%e", param.f);
}
if (param.type == 4)
{
fprintf(stderr, "%s", param.s.c_str());
}
if (param.type == 5)
{
fprintf(stderr, "(");
for (size_t i = 0; i < param.ai.size(); i++)
{
fprintf(stderr, "%d", param.ai[i]);
if (i + 1 != param.ai.size())
fprintf(stderr, ",");
}
fprintf(stderr, ")");
}
if (param.type == 6)
{
fprintf(stderr, "(");
for (size_t i = 0; i < param.af.size(); i++)
{
fprintf(stderr, "%e", param.af[i]);
if (i + 1 != param.af.size())
fprintf(stderr, ",");
}
fprintf(stderr, ")");
}
if (param.type == 7)
{
fprintf(stderr, "(");
for (size_t i = 0; i < param.as.size(); i++)
{
fprintf(stderr, "%s", param.as[i].c_str());
if (i + 1 != param.as.size())
fprintf(stderr, ",");
}
fprintf(stderr, ")");
}
fprintf(stderr, "\n");

continue;
}

const int idkey = std::stoi(it.first);
if (param.type == 2)
{
fprintf(paramfp, " %d=%d", idkey, param.i);
}
if (param.type == 3)
{
fprintf(paramfp, " %d=%e", idkey, param.f);
}
if (param.type == 5)
{
const int array_size = (int)param.ai.size();
fprintf(paramfp, " %d=%d", -23300 - idkey, array_size);
for (size_t i = 0; i < param.ai.size(); i++)
{
fprintf(paramfp, ",%d", param.ai[i]);
}
}
if (param.type == 6)
{
const int array_size = (int)param.af.size();
fprintf(paramfp, " %d=%d", -23300 - idkey, array_size);
for (size_t i = 0; i < param.af.size(); i++)
{
fprintf(paramfp, ",%e", param.af[i]);
}
}
}

for (const auto& it : op->attrs)
{
// fprintf(paramfp, " @%s=", it.first.c_str());

const Attribute& attr = it.second;

fwrite(attr.data.data(), attr.data.size(), 1, binfp);
}

// if (op->inputnames.size() == op->inputs.size())
// {
// for (size_t i = 0; i < op->inputs.size(); i++)
// {
// const Operand* oprand = op->inputs[i];
// fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str());
// }
// }

// for (const Operand* oprand : op->outputs)
// {
// if (oprand->params.find("__batch_index") == oprand->params.end())
// continue;
//
// const int batch_index = oprand->params.at("__batch_index").i;
//
// fprintf(paramfp, " #%s=%d", oprand->name.c_str(), batch_index);
// }

// for (const Operand* oprand : op->outputs)
// {
// if (oprand->shape.empty())
// continue;
//
// fprintf(paramfp, " #%s=", oprand->name.c_str());
//
// fprintf(paramfp, "(");
// for (int64_t i = 0; i < oprand->shape.size() - 1; i++)
// {
// fprintf(paramfp, "%d,", oprand->shape[i]);
// }
// if (oprand->shape.size() > 0)
// fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]);
// fprintf(paramfp, ")");
//
// fprintf(paramfp, type_to_string(oprand->type));
// }

fprintf(paramfp, "\n");
}

fclose(paramfp);
fclose(binfp);

FILE* pyfp = fopen(pypath.c_str(), "wb");
if (!pyfp)
{
fprintf(stderr, "fopen %s failed\n", pypath.c_str());
return -1;
}

fprintf(pyfp, "import numpy as np\n");
fprintf(pyfp, "import ncnn\n");
fprintf(pyfp, "import torch\n");

fprintf(pyfp, "\n");

// test inference
{
fprintf(pyfp, "def test_inference():\n");
fprintf(pyfp, " torch.manual_seed(0)\n");

for (int input_index = 0;; input_index++)
{
std::string input_name = std::string("in") + std::to_string(input_index);
const Operand* r = get_operand(input_name);
if (!r)
break;

if (type_is_integer(r->type))
{
fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d", r->shape[i]);
if (i + 1 != r->shape.size() || r->shape.size() == 1)
fprintf(pyfp, ", ");
}
fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
}
else
{
fprintf(pyfp, " %s = torch.rand(", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d, ", r->shape[i]);
}
fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
}
}

fprintf(pyfp, " out = []\n");
fprintf(pyfp, "\n");

fprintf(pyfp, " with ncnn.Net() as net:\n");
fprintf(pyfp, " net.load_param(\"%s\")\n", parampath.c_str());
fprintf(pyfp, " net.load_model(\"%s\")\n", binpath.c_str());
fprintf(pyfp, "\n");
fprintf(pyfp, " with net.create_extractor() as ex:\n");

for (int input_index = 0;; input_index++)
{
std::string input_name = std::string("in") + std::to_string(input_index);
const Operand* r = get_operand(input_name);
if (!r)
break;

const int batch_index = r->params.at("__batch_index").i;
if (batch_index != 233)
{
fprintf(pyfp, " ex.input(\"%s\", ncnn.Mat(%s.squeeze(%d).numpy()).clone())\n", input_name.c_str(), input_name.c_str(), batch_index);
}
else
{
fprintf(pyfp, " ex.input(\"%s\", ncnn.Mat(%s.numpy()).clone())\n", input_name.c_str(), input_name.c_str());
}
}

fprintf(pyfp, "\n");

for (int output_index = 0;; output_index++)
{
std::string output_name = std::string("out") + std::to_string(output_index);
const Operand* r = get_operand(output_name);
if (!r)
break;

fprintf(pyfp, " _, %s = ex.extract(\"%s\")\n", output_name.c_str(), output_name.c_str());

const int batch_index = r->params.at("__batch_index").i;
if (batch_index != 233)
{
fprintf(pyfp, " out.append(torch.from_numpy(np.array(%s)).unsqueeze(%d))\n", output_name.c_str(), batch_index);
}
else
{
fprintf(pyfp, " out.append(torch.from_numpy(np.array(%s)))\n", output_name.c_str());
}
}

fprintf(pyfp, "\n");

fprintf(pyfp, " if len(out) == 1:\n");
fprintf(pyfp, " return out[0]\n");
fprintf(pyfp, " else:\n");
fprintf(pyfp, " return tuple(out)\n");
}

fclose(pyfp);

return 0;
}

int Graph::parse(const std::string& param)
{
std::istringstream is(param);
@@ -2731,6 +2430,7 @@ Operator* Graph::new_operator_after(const std::string& type, const std::string&
return op;
}

#if BUILD_PNNX
Operand* Graph::new_operand(const torch::jit::Value* v)
{
Operand* r = new Operand;
@@ -2757,6 +2457,7 @@ Operand* Graph::new_operand(const torch::jit::Value* v)
operands.push_back(r);
return r;
}
#endif // BUILD_PNNX

Operand* Graph::new_operand(const std::string& name)
{
@@ -2777,4 +2478,15 @@ Operand* Graph::get_operand(const std::string& name)
return 0;
}

const Operand* Graph::get_operand(const std::string& name) const
{
for (const Operand* r : operands)
{
if (r->name == name)
return r;
}

return 0;
}

} // namespace pnnx

+ 19
- 8
tools/pnnx/src/ir.h View File

@@ -20,6 +20,7 @@
#include <string>
#include <vector>

#if BUILD_PNNX
namespace torch {
namespace jit {
struct Value;
@@ -29,6 +30,7 @@ struct Node;
namespace at {
class Tensor;
}
#endif // BUILD_PNNX

namespace pnnx {

@@ -114,8 +116,10 @@ public:
{
}

#if BUILD_PNNX
Parameter(const torch::jit::Node* value_node);
Parameter(const torch::jit::Value* value);
#endif // BUILD_PNNX

static Parameter parse_from_string(const std::string& value);

@@ -126,9 +130,11 @@ public:
bool b;
int i;
float f;
std::string s;
std::vector<int> ai;
std::vector<float> af;

// keep std::string typed member the last for cross cxxabi compatibility
std::string s;
std::vector<std::string> as;
};

@@ -142,7 +148,9 @@ public:
{
}

#if BUILD_PNNX
Attribute(const at::Tensor& t);
#endif // BUILD_PNNX

Attribute(const std::initializer_list<int>& shape, const std::vector<float>& t);

@@ -164,8 +172,6 @@ class Operand
public:
void remove_consumer(const Operator* c);

std::string name;

Operator* producer;
std::vector<Operator*> consumers;

@@ -173,6 +179,9 @@ public:
int type;
std::vector<int> shape;

// keep std::string typed member the last for cross cxxabi compatibility
std::string name;

std::map<std::string, Parameter> params;

private:
@@ -185,12 +194,13 @@ private:
class Operator
{
public:
std::string type;
std::string name;

std::vector<Operand*> inputs;
std::vector<Operand*> outputs;

// keep std::string typed member the last for cross cxxabi compatibility
std::string type;
std::string name;

std::vector<std::string> inputnames;
std::map<std::string, Parameter> params;
std::map<std::string, Attribute> attrs;
@@ -213,8 +223,6 @@ public:

int python(const std::string& pypath, const std::string& binpath);

int ncnn(const std::string& parampath, const std::string& binpath, const std::string& pypath);

int parse(const std::string& param);

Operator* new_operator(const std::string& type, const std::string& name);
@@ -223,11 +231,14 @@ public:

Operator* new_operator_after(const std::string& type, const std::string& name, const Operator* cur);

#if BUILD_PNNX
Operand* new_operand(const torch::jit::Value* v);
#endif

Operand* new_operand(const std::string& name);

Operand* get_operand(const std::string& name);
const Operand* get_operand(const std::string& name) const;

std::vector<Operator*> ops;
std::vector<Operand*> operands;


+ 17
- 1
tools/pnnx/src/main.cpp View File

@@ -39,6 +39,11 @@
#include "pass_level5.h"

#include "pass_ncnn.h"
#include "save_ncnn.h"

#if BUILD_PNNX2ONNX
#include "save_onnx.h"
#endif

static std::string get_basename(const std::string& path)
{
@@ -159,6 +164,7 @@ static void show_usage()
fprintf(stderr, " pnnxparam=model.pnnx.param\n");
fprintf(stderr, " pnnxbin=model.pnnx.bin\n");
fprintf(stderr, " pnnxpy=model_pnnx.py\n");
fprintf(stderr, " pnnxonnx=model.pnnx.onnx\n");
fprintf(stderr, " ncnnparam=model.ncnn.param\n");
fprintf(stderr, " ncnnbin=model.ncnn.bin\n");
fprintf(stderr, " ncnnpy=model_ncnn.py\n");
@@ -200,6 +206,7 @@ int main(int argc, char** argv)
std::string pnnxparampath = ptbase + ".pnnx.param";
std::string pnnxbinpath = ptbase + ".pnnx.bin";
std::string pnnxpypath = ptbase + "_pnnx.py";
std::string pnnxonnxpath = ptbase + ".pnnx.onnx";
std::string ncnnparampath = ptbase + ".ncnn.param";
std::string ncnnbinpath = ptbase + ".ncnn.bin";
std::string ncnnpypath = ptbase + "_ncnn.py";
@@ -235,6 +242,8 @@ int main(int argc, char** argv)
pnnxbinpath = std::string(value);
if (strcmp(key, "pnnxpy") == 0)
pnnxpypath = std::string(value);
if (strcmp(key, "pnnxonnx") == 0)
pnnxonnxpath = std::string(value);
if (strcmp(key, "ncnnparam") == 0)
ncnnparampath = std::string(value);
if (strcmp(key, "ncnnbin") == 0)
@@ -260,6 +269,7 @@ int main(int argc, char** argv)
fprintf(stderr, "pnnxparam = %s\n", pnnxparampath.c_str());
fprintf(stderr, "pnnxbin = %s\n", pnnxbinpath.c_str());
fprintf(stderr, "pnnxpy = %s\n", pnnxpypath.c_str());
fprintf(stderr, "pnnxonnx = %s\n", pnnxonnxpath.c_str());
fprintf(stderr, "ncnnparam = %s\n", ncnnparampath.c_str());
fprintf(stderr, "ncnnbin = %s\n", ncnnbinpath.c_str());
fprintf(stderr, "ncnnpy = %s\n", ncnnpypath.c_str());
@@ -400,13 +410,19 @@ int main(int argc, char** argv)

pnnx_graph.python(pnnxpypath, pnnxbinpath);

#if BUILD_PNNX2ONNX
pnnx::save_onnx(pnnx_graph, pnnxonnxpath.c_str());
#else
fprintf(stderr, "pnnx build without onnx-zero support, skip saving onnx\n");
#endif

// if (optlevel >= 2)
{
fprintf(stderr, "############# pass_ncnn\n");

pnnx::pass_ncnn(pnnx_graph);

pnnx_graph.ncnn(ncnnparampath, ncnnbinpath, ncnnpypath);
pnnx::save_ncnn(pnnx_graph, ncnnparampath, ncnnbinpath, ncnnpypath);
}

// pnnx::Graph pnnx_graph2;


+ 505
- 0
tools/pnnx/src/onnx.proto View File

@@ -0,0 +1,505 @@
//
// WARNING: This file is automatically generated! Please edit onnx.in.proto.
//


// Copyright (c) ONNX Project Contributors.
// Licensed under the MIT license.

syntax = "proto2";

package onnx;

// Overview
//
// ONNX is an open specification that is comprised of the following components:
//
// 1) A definition of an extensible computation graph model.
// 2) Definitions of standard data types.
// 3) Definitions of built-in operators.
//
// This document describes the syntax of models and their computation graphs,
// as well as the standard data types. Together, they are referred to as the ONNX
// Intermediate Representation, or 'IR' for short.
//
// The normative semantic specification of the ONNX IR is found in docs/IR.md.
// Definitions of the built-in neural network operators may be found in docs/Operators.md.

// Notes
//
// Release
//
// We are still in the very early stage of defining ONNX. The current
// version of ONNX is a starting point. While we are actively working
// towards a complete spec, we would like to get the community involved
// by sharing our working version of ONNX.
//
// Protobuf compatibility
//
// To simplify framework compatibility, ONNX is defined using the subset of protobuf
// that is compatible with both protobuf v2 and v3. This means that we do not use any
// protobuf features that are only available in one of the two versions.
//
// Here are the most notable contortions we have to carry out to work around
// these limitations:
//
// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
// of key-value pairs, where order does not matter and duplicates
// are not allowed.


// Versioning
//
// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
//
// To be compatible with both proto2 and proto3, we will use a version number
// that is not defined by the default value but an explicit enum number.
enum Version {
// proto3 requires the first enum value to be zero.
// We add this just to appease the compiler.
_START_VERSION = 0;
// The version field is always serialized and we will use it to store the
// version that the graph is generated from. This helps us set up version
// control.
// For the IR, we are using simple numbers starting with with 0x00000001,
// which was the version we published on Oct 10, 2017.
IR_VERSION_2017_10_10 = 0x0000000000000001;

// IR_VERSION 2 published on Oct 30, 2017
// - Added type discriminator to AttributeProto to support proto3 users
IR_VERSION_2017_10_30 = 0x0000000000000002;

// IR VERSION 3 published on Nov 3, 2017
// - For operator versioning:
// - Added new message OperatorSetIdProto
// - Added opset_import in ModelProto
// - For vendor extensions, added domain in NodeProto
IR_VERSION_2017_11_3 = 0x0000000000000003;

// IR VERSION 4 published on Jan 22, 2019
// - Relax constraint that initializers should be a subset of graph inputs
// - Add type BFLOAT16
IR_VERSION_2019_1_22 = 0x0000000000000004;

// IR VERSION 5 published on March 18, 2019
// - Add message TensorAnnotation.
// - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
IR_VERSION = 0x0000000000000005;
}

// Attributes
//
// A named attribute containing either singular float, integer, string, graph,
// and tensor values, or repeated float, integer, string, graph, and tensor values.
// An AttributeProto MUST contain the name field, and *only one* of the
// following content fields, effectively enforcing a C/C++ union equivalent.
message AttributeProto {

// Note: this enum is structurally identical to the OpSchema::AttrType
// enum defined in schema.h. If you rev one, you likely need to rev the other.
enum AttributeType {
UNDEFINED = 0;
FLOAT = 1;
INT = 2;
STRING = 3;
TENSOR = 4;
GRAPH = 5;

FLOATS = 6;
INTS = 7;
STRINGS = 8;
TENSORS = 9;
GRAPHS = 10;
}

// The name field MUST be present for this version of the IR.
optional string name = 1; // namespace Attribute
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope.
// NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
optional string ref_attr_name = 21;

// A human-readable documentation for this attribute. Markdown is allowed.
optional string doc_string = 13;

// The type field MUST be present for this version of the IR.
// For 0.0.1 versions of the IR, this field was not defined, and
// implementations needed to use has_field hueristics to determine
// which value field was in use. For IR_VERSION 0.0.2 or later, this
// field MUST be set and match the f|i|s|t|... field in use. This
// change was made to accommodate proto3 implementations.
optional AttributeType type = 20; // discriminator that indicates which field below is in use

// Exactly ONE of the following fields must be present for this version of the IR
optional float f = 2; // float
optional int64 i = 3; // int
optional bytes s = 4; // UTF-8 string
optional TensorProto t = 5; // tensor value
optional GraphProto g = 6; // graph
// Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph

repeated float floats = 7; // list of floats
repeated int64 ints = 8; // list of ints
repeated bytes strings = 9; // list of UTF-8 strings
repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph
}

// Defines information on value, including the name, the type, and
// the shape of the value.
message ValueInfoProto {
// This field MUST be present in this version of the IR.
optional string name = 1; // namespace Value
// This field MUST be present in this version of the IR.
optional TypeProto type = 2;
// A human-readable documentation for this value. Markdown is allowed.
optional string doc_string = 3;
}

// Nodes
//
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto {
repeated string input = 1; // namespace Value
repeated string output = 2; // namespace Value

// An optional identifier for this node in a graph.
// This field MAY be absent in ths version of the IR.
optional string name = 3; // namespace Node

// The symbolic identifier of the Operator to execute.
optional string op_type = 4; // namespace Operator
// The domain of the OperatorSet that specifies the operator named by op_type.
optional string domain = 7; // namespace Domain

// Additional named attributes.
repeated AttributeProto attribute = 5;

// A human-readable documentation for this node. Markdown is allowed.
optional string doc_string = 6;
}

// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
// The semantics of the model are described by the associated GraphProto.
message ModelProto {
// The version of the IR this model targets. See Version enum above.
// This field MUST be present.
optional int64 ir_version = 1;

// The OperatorSets this model relies on.
// All ModelProtos MUST have at least one entry that
// specifies which version of the ONNX OperatorSet is
// being imported.
//
// All nodes in the ModelProto's graph will bind against the operator
// with the same-domain/same-op_type operator with the HIGHEST version
// in the referenced operator sets.
repeated OperatorSetIdProto opset_import = 8;

// The name of the framework or tool used to generate this model.
// This field SHOULD be present to indicate which implementation/tool/framework
// emitted the model.
optional string producer_name = 2;

// The version of the framework or tool used to generate this model.
// This field SHOULD be present to indicate which implementation/tool/framework
// emitted the model.
optional string producer_version = 3;

// Domain name of the model.
// We use reverse domain names as name space indicators. For example:
// `com.facebook.fair` or `com.microsoft.cognitiveservices`
//
// Together with `model_version` and GraphProto.name, this forms the unique identity of
// the graph.
optional string domain = 4;

// The version of the graph encoded. See Version enum below.
optional int64 model_version = 5;

// A human-readable documentation for this model. Markdown is allowed.
optional string doc_string = 6;

// The parameterized graph that is evaluated to execute the model.
optional GraphProto graph = 7;

// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 14;
};

// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
optional string key = 1;
optional string value= 2;
};

message TensorAnnotation {
optional string tensor_name = 1;
// <key, value> pairs to annotate tensor specified by <tensor_name> above.
// The keys used in the mapping below must be pre-defined in ONNX spec.
// For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
// quantization parameter keys.
repeated StringStringEntryProto quant_parameter_tensor_names = 2;
}



// Graphs
//
// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
message GraphProto {
// The nodes in the graph, sorted topologically.
repeated NodeProto node = 1;

// The name of the graph.
optional string name = 2; // namespace Graph

// A list of named tensor values, used to specify constant inputs of the graph.
// Each TensorProto entry must have a distinct name (within the list) that
// MAY also appear in the input list.
repeated TensorProto initializer = 5;

// A human-readable documentation for this graph. Markdown is allowed.
optional string doc_string = 10;

// The inputs and outputs of the graph.
repeated ValueInfoProto input = 11;
repeated ValueInfoProto output = 12;

// Information for the values in the graph. The ValueInfoProto.name's
// must be distinct. It is optional for a value to appear in value_info list.
repeated ValueInfoProto value_info = 13;

// This field carries information to indicate the mapping among a tensor and its
// quantization parameter tensors. For example:
// For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
// which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
repeated TensorAnnotation quantization_annotation = 14;

// DO NOT USE the following fields, they were deprecated from earlier versions.
// repeated string input = 3;
// repeated string output = 4;
// optional int64 ir_version = 6;
// optional int64 producer_version = 7;
// optional string producer_tag = 8;
// optional string domain = 9;
}

// Tensors
//
// A serialized tensor value.
message TensorProto {
enum DataType {
UNDEFINED = 0;
// Basic types.
FLOAT = 1; // float
UINT8 = 2; // uint8_t
INT8 = 3; // int8_t
UINT16 = 4; // uint16_t
INT16 = 5; // int16_t
INT32 = 6; // int32_t
INT64 = 7; // int64_t
STRING = 8; // string
BOOL = 9; // bool

// IEEE754 half-precision floating-point format (16 bits wide).
// This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
FLOAT16 = 10;

DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components

// Non-IEEE floating-point format based on IEEE754 single-precision
// floating-point number truncated to 16 bits.
// This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
BFLOAT16 = 16;

// Future extensions go here.
}

// The shape of the tensor.
repeated int64 dims = 1;

// The data type of the tensor.
// This field MUST have a valid TensorProto.DataType value
optional int32 data_type = 2;

// For very large tensors, we may want to store them in chunks, in which
// case the following fields will specify the segment that is stored in
// the current TensorProto.
message Segment {
optional int64 begin = 1;
optional int64 end = 2;
}
optional Segment segment = 3;

// Tensor content must be organized in row-major order.
//
// Depending on the data_type field, exactly one of the fields below with
// name ending in _data is used to store the elements of the tensor.

// For float and complex64 values
// Complex64 tensors are encoded as a single array of floats,
// with the real components appearing in odd numbered positions,
// and the corresponding imaginary component apparing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true];

// For int32, uint8, int8, uint16, int16, bool, and float16 values
// float16 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer.
// When this field is present, the data_type field MUST be
// INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16
repeated int32 int32_data = 5 [packed = true];

// For strings.
// Each element of string_data is a UTF-8 encoded Unicode
// string. No trailing null, no leading BOM. The protobuf "string"
// scalar type is not used to match ML community conventions.
// When this field is present, the data_type field MUST be STRING
repeated bytes string_data = 6;

// For int64.
// When this field is present, the data_type field MUST be INT64
repeated int64 int64_data = 7 [packed = true];

// Optionally, a name for the tensor.
optional string name = 8; // namespace Value

// A human-readable documentation for this tensor. Markdown is allowed.
optional string doc_string = 12;

// Serializations can either use one of the fields above, or use this
// raw bytes field. The only exception is the string case, where one is
// required to store the content in the repeated bytes string_data field.
//
// When this raw_data field is used to store tensor value, elements MUST
// be stored in as fixed-width, little-endian order.
// Floating-point data types MUST be stored in IEEE 754 format.
// Complex64 elements must be written as two consecutive FLOAT values, real component first.
// Complex128 elements must be written as two consecutive DOUBLE values, real component first.
// Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
//
// Note: the advantage of specific field rather than the raw_data field is
// that in some cases (e.g. int data), protobuf does a better packing via
// variable length storage, and may lead to smaller binary footprint.
// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
optional bytes raw_data = 9;

// Data can be stored inside the protobuf file using type-specific fields or raw_data.
// Alternatively, raw bytes data can be stored in an external file, using the external_data field.
// external_data stores key-value pairs describing data location. Recognized keys are:
// - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
// protobuf model was stored
// - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
// Offset values SHOULD be multiples 4096 (page size) to enable mmap support.
// - "length" (optional) - number of bytes containing data. Integer stored as string.
// - "checksum" (optional) - SHA1 digest of file specified in under 'location' key.
repeated StringStringEntryProto external_data = 13;

// Location of the data for this tensor. MUST be one of:
// - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
// - EXTERNAL - data stored in an external location as described by external_data field.
enum DataLocation {
DEFAULT = 0;
EXTERNAL = 1;
}

// If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
optional DataLocation data_location = 14;

// For double
// Complex128 tensors are encoded as a single array of doubles,
// with the real components appearing in odd numbered positions,
// and the corresponding imaginary component apparing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
repeated double double_data = 10 [packed = true];

// For uint64 and uint32 values
// When this field is present, the data_type field MUST be
// UINT32 or UINT64
repeated uint64 uint64_data = 11 [packed = true];
}

// Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown
// dimension.
message TensorShapeProto {
message Dimension {
oneof value {
int64 dim_value = 1;
string dim_param = 2; // namespace Shape
};
// Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor.
// Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition
// for pre-defined dimension denotations.
optional string denotation = 3;
};
repeated Dimension dim = 1;
}

// Types
//
// The standard ONNX data types.
message TypeProto {

message Tensor {
// This field MUST NOT have the value of UNDEFINED
// This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
optional int32 elem_type = 1;
optional TensorShapeProto shape = 2;
}


oneof value {
// The type of a tensor.
Tensor tensor_type = 1;

}

// An optional denotation can be used to denote the whole
// type with a standard semantic description as to what is
// stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition
// for pre-defined type denotations.
optional string denotation = 6;
}

// Operator Sets
//
// OperatorSets are uniquely identified by a (domain, opset_version) pair.
message OperatorSetIdProto {
// The domain of the operator set being identified.
// The empty string ("") or absence of this field implies the operator
// set that is defined as part of the ONNX specification.
// This field MUST be present in this version of the IR when referring to any other operator set.
optional string domain = 1;

// The version of the operator set being identified.
// This field MUST be present in this version of the IR.
optional int64 version = 2;
}

+ 361
- 0
tools/pnnx/src/save_ncnn.cpp View File

@@ -0,0 +1,361 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "save_ncnn.h"

namespace pnnx {

static bool type_is_integer(int type)
{
if (type == 1) return false;
if (type == 2) return false;
if (type == 3) return false;
if (type == 4) return true;
if (type == 5) return true;
if (type == 6) return true;
if (type == 7) return true;
if (type == 8) return true;
if (type == 9) return true;
if (type == 10) return false;
if (type == 11) return false;
if (type == 12) return false;
return false;
}

static const char* type_to_dtype_string(int type)
{
if (type == 1) return "torch.float";
if (type == 2) return "torch.double";
if (type == 3) return "torch.half";
if (type == 4) return "torch.int";
if (type == 5) return "torch.long";
if (type == 6) return "torch.short";
if (type == 7) return "torch.int8";
if (type == 8) return "torch.uint8";
if (type == 9) return "torch.bool";
if (type == 10) return "torch.complex64";
if (type == 11) return "torch.complex128";
if (type == 12) return "torch.complex32";
return "null";
}

static bool string_is_positive_integer(const std::string& t)
{
for (size_t i = 0; i < t.size(); i++)
{
if (t[i] < '0' || t[i] > '9')
return false;
}

return true;
}

int save_ncnn(const Graph& g, const std::string& parampath, const std::string& binpath, const std::string& pypath)
{
FILE* paramfp = fopen(parampath.c_str(), "wb");
if (!paramfp)
{
fprintf(stderr, "fopen %s failed\n", parampath.c_str());
return -1;
}

FILE* binfp = fopen(binpath.c_str(), "wb");
if (!binfp)
{
fprintf(stderr, "fopen %s failed\n", binpath.c_str());
fclose(paramfp);
return -1;
}

// magic
fprintf(paramfp, "7767517\n");

// op count and oprand count
fprintf(paramfp, "%d %d\n", (int)g.ops.size(), (int)g.operands.size());

for (const Operator* op : g.ops)
{
fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size());

for (const Operand* oprand : op->inputs)
{
fprintf(paramfp, " %s", oprand->name.c_str());
}

for (const Operand* oprand : op->outputs)
{
fprintf(paramfp, " %s", oprand->name.c_str());
}

for (const auto& it : op->params)
{
const Parameter& param = it.second;

if (!string_is_positive_integer(it.first))
{
fprintf(stderr, "ignore %s %s param %s=", op->type.c_str(), op->name.c_str(), it.first.c_str());

if (param.type == 0)
{
fprintf(stderr, "None");
}
if (param.type == 1)
{
if (param.b)
fprintf(stderr, "True");
else
fprintf(stderr, "False");
}
if (param.type == 2)
{
fprintf(stderr, "%d", param.i);
}
if (param.type == 3)
{
fprintf(stderr, "%e", param.f);
}
if (param.type == 4)
{
fprintf(stderr, "%s", param.s.c_str());
}
if (param.type == 5)
{
fprintf(stderr, "(");
for (size_t i = 0; i < param.ai.size(); i++)
{
fprintf(stderr, "%d", param.ai[i]);
if (i + 1 != param.ai.size())
fprintf(stderr, ",");
}
fprintf(stderr, ")");
}
if (param.type == 6)
{
fprintf(stderr, "(");
for (size_t i = 0; i < param.af.size(); i++)
{
fprintf(stderr, "%e", param.af[i]);
if (i + 1 != param.af.size())
fprintf(stderr, ",");
}
fprintf(stderr, ")");
}
if (param.type == 7)
{
fprintf(stderr, "(");
for (size_t i = 0; i < param.as.size(); i++)
{
fprintf(stderr, "%s", param.as[i].c_str());
if (i + 1 != param.as.size())
fprintf(stderr, ",");
}
fprintf(stderr, ")");
}
fprintf(stderr, "\n");

continue;
}

const int idkey = std::stoi(it.first);
if (param.type == 2)
{
fprintf(paramfp, " %d=%d", idkey, param.i);
}
if (param.type == 3)
{
fprintf(paramfp, " %d=%e", idkey, param.f);
}
if (param.type == 5)
{
const int array_size = (int)param.ai.size();
fprintf(paramfp, " %d=%d", -23300 - idkey, array_size);
for (size_t i = 0; i < param.ai.size(); i++)
{
fprintf(paramfp, ",%d", param.ai[i]);
}
}
if (param.type == 6)
{
const int array_size = (int)param.af.size();
fprintf(paramfp, " %d=%d", -23300 - idkey, array_size);
for (size_t i = 0; i < param.af.size(); i++)
{
fprintf(paramfp, ",%e", param.af[i]);
}
}
}

for (const auto& it : op->attrs)
{
// fprintf(paramfp, " @%s=", it.first.c_str());

const Attribute& attr = it.second;

fwrite(attr.data.data(), attr.data.size(), 1, binfp);
}

// if (op->inputnames.size() == op->inputs.size())
// {
// for (size_t i = 0; i < op->inputs.size(); i++)
// {
// const Operand* oprand = op->inputs[i];
// fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str());
// }
// }

// for (const Operand* oprand : op->outputs)
// {
// if (oprand->params.find("__batch_index") == oprand->params.end())
// continue;
//
// const int batch_index = oprand->params.at("__batch_index").i;
//
// fprintf(paramfp, " #%s=%d", oprand->name.c_str(), batch_index);
// }

// for (const Operand* oprand : op->outputs)
// {
// if (oprand->shape.empty())
// continue;
//
// fprintf(paramfp, " #%s=", oprand->name.c_str());
//
// fprintf(paramfp, "(");
// for (int64_t i = 0; i < oprand->shape.size() - 1; i++)
// {
// fprintf(paramfp, "%d,", oprand->shape[i]);
// }
// if (oprand->shape.size() > 0)
// fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]);
// fprintf(paramfp, ")");
//
// fprintf(paramfp, type_to_string(oprand->type));
// }

fprintf(paramfp, "\n");
}

fclose(paramfp);
fclose(binfp);

FILE* pyfp = fopen(pypath.c_str(), "wb");
if (!pyfp)
{
fprintf(stderr, "fopen %s failed\n", pypath.c_str());
return -1;
}

fprintf(pyfp, "import numpy as np\n");
fprintf(pyfp, "import ncnn\n");
fprintf(pyfp, "import torch\n");

fprintf(pyfp, "\n");

// test inference
{
fprintf(pyfp, "def test_inference():\n");
fprintf(pyfp, " torch.manual_seed(0)\n");

for (int input_index = 0;; input_index++)
{
std::string input_name = std::string("in") + std::to_string(input_index);
const Operand* r = g.get_operand(input_name);
if (!r)
break;

if (type_is_integer(r->type))
{
fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d", r->shape[i]);
if (i + 1 != r->shape.size() || r->shape.size() == 1)
fprintf(pyfp, ", ");
}
fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
}
else
{
fprintf(pyfp, " %s = torch.rand(", input_name.c_str());
for (size_t i = 0; i < r->shape.size(); i++)
{
fprintf(pyfp, "%d, ", r->shape[i]);
}
fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
}
}

fprintf(pyfp, " out = []\n");
fprintf(pyfp, "\n");

fprintf(pyfp, " with ncnn.Net() as net:\n");
fprintf(pyfp, " net.load_param(\"%s\")\n", parampath.c_str());
fprintf(pyfp, " net.load_model(\"%s\")\n", binpath.c_str());
fprintf(pyfp, "\n");
fprintf(pyfp, " with net.create_extractor() as ex:\n");

for (int input_index = 0;; input_index++)
{
std::string input_name = std::string("in") + std::to_string(input_index);
const Operand* r = g.get_operand(input_name);
if (!r)
break;

const int batch_index = r->params.at("__batch_index").i;
if (batch_index != 233)
{
fprintf(pyfp, " ex.input(\"%s\", ncnn.Mat(%s.squeeze(%d).numpy()).clone())\n", input_name.c_str(), input_name.c_str(), batch_index);
}
else
{
fprintf(pyfp, " ex.input(\"%s\", ncnn.Mat(%s.numpy()).clone())\n", input_name.c_str(), input_name.c_str());
}
}

fprintf(pyfp, "\n");

for (int output_index = 0;; output_index++)
{
std::string output_name = std::string("out") + std::to_string(output_index);
const Operand* r = g.get_operand(output_name);
if (!r)
break;

fprintf(pyfp, " _, %s = ex.extract(\"%s\")\n", output_name.c_str(), output_name.c_str());

const int batch_index = r->params.at("__batch_index").i;
if (batch_index != 233)
{
fprintf(pyfp, " out.append(torch.from_numpy(np.array(%s)).unsqueeze(%d))\n", output_name.c_str(), batch_index);
}
else
{
fprintf(pyfp, " out.append(torch.from_numpy(np.array(%s)))\n", output_name.c_str());
}
}

fprintf(pyfp, "\n");

fprintf(pyfp, " if len(out) == 1:\n");
fprintf(pyfp, " return out[0]\n");
fprintf(pyfp, " else:\n");
fprintf(pyfp, " return tuple(out)\n");
}

fclose(pyfp);

return 0;
}

} // namespace pnnx

+ 26
- 0
tools/pnnx/src/save_ncnn.h View File

@@ -0,0 +1,26 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef PNNX_SAVE_NCNN_H
#define PNNX_SAVE_NCNN_H

#include "ir.h"

namespace pnnx {

int save_ncnn(const Graph& g, const std::string& parampath, const std::string& binpath, const std::string& pypath);

} // namespace pnnx

#endif // PNNX_SAVE_NCNN_H

+ 268
- 0
tools/pnnx/src/save_onnx.cpp View File

@@ -0,0 +1,268 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "save_onnx.h"

#include "onnx.pb.h"

#include <string.h>
#include <fstream>
#include <iostream>

namespace pnnx {

// from cxxabi bridge
extern const char* get_operand_name(const Operand* x);
extern const char* get_operator_type(const Operator* op);
extern const char* get_operator_name(const Operator* op);
extern std::vector<const char*> get_operator_params_keys(const Operator* op);
extern std::vector<const char*> get_operator_attrs_keys(const Operator* op);
extern const Parameter& get_operator_param(const Operator* op, const char* key);
extern const Attribute& get_operator_attr(const Operator* op, const char* key);
extern const char* get_param_s(const Parameter& p);
extern std::vector<const char*> get_param_as(const Parameter& p);

int save_onnx(const Graph& g, const char* onnxpath)
{
onnx::ModelProto model;

onnx::GraphProto* gp = model.mutable_graph();

for (const Operand* x : g.operands)
{
onnx::ValueInfoProto* vip = gp->add_value_info();

vip->set_name(get_operand_name(x));

onnx::TypeProto* tp = vip->mutable_type();

onnx::TypeProto_Tensor* tpt = tp->mutable_tensor_type();

switch (x->type)
{
case 1: // f32
tpt->set_elem_type(1);
break;
case 2: // f64
tpt->set_elem_type(11);
break;
case 3: // f16
tpt->set_elem_type(10);
break;
case 4: // i32
tpt->set_elem_type(6);
break;
case 5: // i64
tpt->set_elem_type(7);
break;
case 6: // i16
tpt->set_elem_type(5);
break;
case 7: // i8
tpt->set_elem_type(3);
break;
case 8: // u8
tpt->set_elem_type(2);
break;
case 9: // bool
tpt->set_elem_type(9);
break;
case 10: // cp64
tpt->set_elem_type(14);
break;
case 11: // cp128
tpt->set_elem_type(15);
break;
case 12: // cp32
tpt->set_elem_type(0);
break;
default: // null
tpt->set_elem_type(0);
break;
}

onnx::TensorShapeProto* tsp = tpt->mutable_shape();

for (auto s : x->shape)
{
onnx::TensorShapeProto_Dimension* tspd = tsp->add_dim();

tspd->set_dim_value(s);
}
}

for (const Operator* op : g.ops)
{
onnx::NodeProto* np = gp->add_node();

np->set_op_type(get_operator_type(op));
np->set_name(get_operator_name(op));

for (const Operand* oprand : op->inputs)
{
np->add_input(get_operand_name(oprand));
}

for (const Operand* oprand : op->outputs)
{
np->add_output(get_operand_name(oprand));
}

std::vector<const char*> params_keys = get_operator_params_keys(op);

// for (const auto& it : op->params)
for (const char* param_name : params_keys)
{
// const Parameter& param = it.second;
const Parameter& param = get_operator_param(op, param_name);

onnx::AttributeProto* ap = np->add_attribute();

// ap->set_name(get_param_name(it));
ap->set_name(param_name);

if (param.type == 0)
{
ap->set_s("None");
}
if (param.type == 1)
{
if (param.b)
ap->set_i(1);
else
ap->set_i(0);
}
if (param.type == 2)
{
ap->set_i(param.i);
}
if (param.type == 3)
{
ap->set_f(param.f);
}
if (param.type == 4)
{
ap->set_s(get_param_s(param));
}
if (param.type == 5)
{
for (auto i : param.ai)
{
ap->add_ints(i);
}
}
if (param.type == 6)
{
for (auto f : param.af)
{
ap->add_floats(f);
}
}
if (param.type == 7)
{
std::vector<const char*> as = get_param_as(param);
for (auto s : as)
{
ap->add_strings(s);
}
}
}

std::vector<const char*> attrs_keys = get_operator_attrs_keys(op);

// for (const auto& it : op->attrs)
for (const char* attr_name : attrs_keys)
{
onnx::TensorProto* tp = gp->add_initializer();

tp->set_name(std::string(get_operator_name(op)) + "." + attr_name);

np->add_input(std::string(get_operator_name(op)) + "." + attr_name);

// const Attribute& attr = it.second;
const Attribute& attr = get_operator_attr(op, attr_name);
for (auto s : attr.shape)
{
tp->add_dims(s);
}

switch (attr.type)
{
case 1: // f32
tp->set_data_type(1);
break;
case 2: // f64
tp->set_data_type(11);
break;
case 3: // f16
tp->set_data_type(10);
break;
case 4: // i32
tp->set_data_type(6);
break;
case 5: // i64
tp->set_data_type(7);
break;
case 6: // i16
tp->set_data_type(5);
break;
case 7: // i8
tp->set_data_type(3);
break;
case 8: // u8
tp->set_data_type(2);
break;
case 9: // bool
tp->set_data_type(9);
break;
case 10: // cp64
tp->set_data_type(14);
break;
case 11: // cp128
tp->set_data_type(15);
break;
case 12: // cp32
tp->set_data_type(0);
break;
default: // null
tp->set_data_type(0);
break;
}

std::string* d = tp->mutable_raw_data();
d->resize(attr.data.size());
memcpy((void*)d->data(), attr.data.data(), attr.data.size());
}

// if (op->inputnames.size() == op->inputs.size())
// {
// for (size_t i = 0; i < op->inputs.size(); i++)
// {
// const Operand* oprand = op->inputs[i];
// fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str());
// }
// }
}

std::fstream output(onnxpath, std::ios::out | std::ios::trunc | std::ios::binary);
if (!model.SerializeToOstream(&output))
{
fprintf(stderr, "write onnx failed\n");
return -1;
}

return 0;
}

} // namespace pnnx

+ 26
- 0
tools/pnnx/src/save_onnx.h View File

@@ -0,0 +1,26 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef PNNX_SAVE_ONNX_H
#define PNNX_SAVE_ONNX_H

#include "ir.h"

namespace pnnx {

int save_onnx(const Graph& g, const char* onnxpath);

} // namespace pnnx

#endif // PNNX_SAVE_ONNX_H

+ 81
- 0
tools/pnnx/src/save_onnx_cxxabi_bridge.cpp View File

@@ -0,0 +1,81 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

const char* get_operand_name(const Operand* x)
{
return x->name.c_str();
}

const char* get_operator_type(const Operator* op)
{
return op->type.c_str();
}

const char* get_operator_name(const Operator* op)
{
return op->name.c_str();
}

std::vector<const char*> get_operator_params_keys(const Operator* op)
{
std::vector<const char*> keys;
for (const auto& it : op->params)
{
const std::string& key = it.first;
keys.push_back(key.c_str());
}
return keys;
}

std::vector<const char*> get_operator_attrs_keys(const Operator* op)
{
std::vector<const char*> keys;
for (const auto& it : op->attrs)
{
const std::string& key = it.first;
keys.push_back(key.c_str());
}
return keys;
}

const Parameter& get_operator_param(const Operator* op, const char* key)
{
return op->params.at(key);
}

const Attribute& get_operator_attr(const Operator* op, const char* key)
{
return op->attrs.at(key);
}

const char* get_param_s(const Parameter& p)
{
return p.s.c_str();
}

std::vector<const char*> get_param_as(const Parameter& p)
{
std::vector<const char*> as;
for (const auto& s : p.as)
{
as.push_back(s.c_str());
}
return as;
}

} // namespace pnnx

Loading…
Cancel
Save