|
- // Tencent is pleased to support the open source community by making ncnn available.
- //
- // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
- //
- // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
- // in compliance with the License. You may obtain a copy of the License at
- //
- // https://opensource.org/licenses/BSD-3-Clause
- //
- // Unless required by applicable law or agreed to in writing, software distributed
- // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
- // CONDITIONS OF ANY KIND, either express or implied. See the License for the
- // specific language governing permissions and limitations under the License.
-
- #include "save_onnx.h"
-
- #include "onnx.pb.h"
-
- #include <string.h>
- #include <fstream>
- #include <iostream>
-
- namespace pnnx {
-
- // from cxxabi bridge
- extern const char* get_operand_name(const Operand* x);
- extern const char* get_operator_type(const Operator* op);
- extern const char* get_operator_name(const Operator* op);
- extern std::vector<const char*> get_operator_params_keys(const Operator* op);
- extern std::vector<const char*> get_operator_attrs_keys(const Operator* op);
- extern const Parameter& get_operator_param(const Operator* op, const char* key);
- extern const Attribute& get_operator_attr(const Operator* op, const char* key);
- extern const char* get_param_s(const Parameter& p);
- extern std::vector<const char*> get_param_as(const Parameter& p);
-
// Convert an IEEE-754 binary32 value into the bit pattern of a binary16 value.
//
// The conversion is intentionally simple and lossy:
//   - the significand is truncated to 10 bits, no rounding is applied
//   - zero and fp32 denormals become signed zero
//   - normal fp32 values too small for a normal fp16 become signed zero
//   - values too large for fp16 become signed infinity
//   - infinity stays infinity, any NaN becomes a quiet NaN with the payload dropped
//
// value   the fp32 number to convert
// return  the 16 bit pattern, laid out as sign : exponent : significand = 1 : 5 : 10
static unsigned short float32_to_float16(float value)
{
    static_assert(sizeof(unsigned int) == sizeof(float), "unsigned int must be able to hold the fp32 bit pattern exactly");

    // fetch the raw bits with memcpy, reading the inactive member of a union is undefined behavior in c++
    unsigned int bits;
    memcpy(&bits, &value, sizeof(bits));

    // 1 : 8 : 23
    const unsigned short sign = (bits & 0x80000000) >> 31;
    const unsigned short exponent = (bits & 0x7F800000) >> 23;
    const unsigned int significand = bits & 0x7FFFFF;

    // 1 : 5 : 10
    const unsigned short signed_zero = sign << 15;
    const unsigned short signed_inf = (sign << 15) | (0x1F << 10);

    if (exponent == 0)
    {
        // zero or denormal, always underflow
        return signed_zero;
    }

    if (exponent == 0xFF)
    {
        // infinity or NaN, NaN keeps only the quiet bit
        const unsigned short quiet_bit = significand ? 0x200 : 0x00;
        return signed_inf | quiet_bit;
    }

    // normalized, rebias the exponent from 127 to 15
    const int newexp = exponent + (-127 + 15);

    if (newexp >= 31)
    {
        // overflow, return infinity
        return signed_inf;
    }

    if (newexp <= 0)
    {
        // some normal fp32 cannot be expressed as normal fp16
        return signed_zero;
    }

    // normal fp16, keep the top 10 bits of the 23 bit significand
    const unsigned short fp16 = (sign << 15) | (newexp << 10) | (significand >> 13);
    return fp16;
}
-
- int save_onnx(const Graph& g, const char* onnxpath, int fp16)
- {
- onnx::ModelProto model;
-
- onnx::GraphProto* gp = model.mutable_graph();
-
- for (const Operand* x : g.operands)
- {
- onnx::ValueInfoProto* vip = gp->add_value_info();
-
- vip->set_name(get_operand_name(x));
-
- onnx::TypeProto* tp = vip->mutable_type();
-
- onnx::TypeProto_Tensor* tpt = tp->mutable_tensor_type();
-
- switch (x->type)
- {
- case 1: // f32
- tpt->set_elem_type(fp16 ? 10 : 1);
- break;
- case 2: // f64
- tpt->set_elem_type(fp16 ? 10 : 11);
- break;
- case 3: // f16
- tpt->set_elem_type(10);
- break;
- case 4: // i32
- tpt->set_elem_type(6);
- break;
- case 5: // i64
- tpt->set_elem_type(7);
- break;
- case 6: // i16
- tpt->set_elem_type(5);
- break;
- case 7: // i8
- tpt->set_elem_type(3);
- break;
- case 8: // u8
- tpt->set_elem_type(2);
- break;
- case 9: // bool
- tpt->set_elem_type(9);
- break;
- case 10: // cp64
- tpt->set_elem_type(14);
- break;
- case 11: // cp128
- tpt->set_elem_type(15);
- break;
- case 12: // cp32
- tpt->set_elem_type(0);
- break;
- default: // null
- tpt->set_elem_type(0);
- break;
- }
-
- onnx::TensorShapeProto* tsp = tpt->mutable_shape();
-
- for (auto s : x->shape)
- {
- onnx::TensorShapeProto_Dimension* tspd = tsp->add_dim();
-
- tspd->set_dim_value(s);
- }
- }
-
- for (const Operator* op : g.ops)
- {
- onnx::NodeProto* np = gp->add_node();
-
- np->set_op_type(get_operator_type(op));
- np->set_name(get_operator_name(op));
-
- for (const Operand* oprand : op->inputs)
- {
- np->add_input(get_operand_name(oprand));
- }
-
- for (const Operand* oprand : op->outputs)
- {
- np->add_output(get_operand_name(oprand));
- }
-
- std::vector<const char*> params_keys = get_operator_params_keys(op);
- for (const char* param_name : params_keys)
- {
- const Parameter& param = get_operator_param(op, param_name);
-
- onnx::AttributeProto* ap = np->add_attribute();
-
- ap->set_name(param_name);
-
- if (param.type == 0)
- {
- ap->set_s("None");
- }
- if (param.type == 1)
- {
- if (param.b)
- ap->set_i(1);
- else
- ap->set_i(0);
- }
- if (param.type == 2)
- {
- ap->set_i(param.i);
- }
- if (param.type == 3)
- {
- ap->set_f(param.f);
- }
- if (param.type == 4)
- {
- ap->set_s(get_param_s(param));
- }
- if (param.type == 5)
- {
- for (auto i : param.ai)
- {
- ap->add_ints(i);
- }
- }
- if (param.type == 6)
- {
- for (auto f : param.af)
- {
- ap->add_floats(f);
- }
- }
- if (param.type == 7)
- {
- std::vector<const char*> as = get_param_as(param);
- for (auto s : as)
- {
- ap->add_strings(s);
- }
- }
- }
-
- std::vector<const char*> attrs_keys = get_operator_attrs_keys(op);
- for (const char* attr_name : attrs_keys)
- {
- onnx::TensorProto* tp = gp->add_initializer();
-
- tp->set_name(std::string(get_operator_name(op)) + "." + attr_name);
-
- np->add_input(std::string(get_operator_name(op)) + "." + attr_name);
-
- const Attribute& attr = get_operator_attr(op, attr_name);
- for (auto s : attr.shape)
- {
- tp->add_dims(s);
- }
-
- switch (attr.type)
- {
- case 1: // f32
- tp->set_data_type(fp16 ? 10 : 1);
- break;
- case 2: // f64
- tp->set_data_type(fp16 ? 10 : 11);
- break;
- case 3: // f16
- tp->set_data_type(10);
- break;
- case 4: // i32
- tp->set_data_type(6);
- break;
- case 5: // i64
- tp->set_data_type(7);
- break;
- case 6: // i16
- tp->set_data_type(5);
- break;
- case 7: // i8
- tp->set_data_type(3);
- break;
- case 8: // u8
- tp->set_data_type(2);
- break;
- case 9: // bool
- tp->set_data_type(9);
- break;
- case 10: // cp64
- tp->set_data_type(14);
- break;
- case 11: // cp128
- tp->set_data_type(15);
- break;
- case 12: // cp32
- tp->set_data_type(0);
- break;
- default: // null
- tp->set_data_type(0);
- break;
- }
-
- std::string* d = tp->mutable_raw_data();
- if (fp16 && attr.type == 1)
- {
- // fp32 to fp16
- const float* p = (const float*)attr.data.data();
- int len = attr.data.size() / 4;
- d->resize(len * 2);
- unsigned short* p_fp16 = (unsigned short*)d->data();
- for (int i = 0; i < len; i++)
- {
- p_fp16[i] = float32_to_float16(p[i]);
- }
- }
- else if (fp16 && attr.type == 2)
- {
- // fp64 to fp16
- const double* p = (const double*)attr.data.data();
- int len = attr.data.size() / 4;
- d->resize(len);
- unsigned short* p_fp16 = (unsigned short*)d->data();
- for (int i = 0; i < len; i++)
- {
- p_fp16[i] = float32_to_float16((float)p[i]);
- }
- }
- else
- {
- d->resize(attr.data.size());
- memcpy((void*)d->data(), attr.data.data(), attr.data.size());
- }
- }
- }
-
- std::fstream output(onnxpath, std::ios::out | std::ios::trunc | std::ios::binary);
- if (!model.SerializeToOstream(&output))
- {
- fprintf(stderr, "write onnx failed\n");
- return -1;
- }
-
- return 0;
- }
-
- } // namespace pnnx
|