|
- // Tencent is pleased to support the open source community by making ncnn available.
- //
- // Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
- //
- // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
- // in compliance with the License. You may obtain a copy of the License at
- //
- // https://opensource.org/licenses/BSD-3-Clause
- //
- // Unless required by applicable law or agreed to in writing, software distributed
- // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
- // CONDITIONS OF ANY KIND, either express or implied. See the License for the
- // specific language governing permissions and limitations under the License.
-
- #include "save_ncnn.h"
-
- namespace pnnx {
-
- static bool type_is_integer(int type)
- {
- if (type == 1) return false;
- if (type == 2) return false;
- if (type == 3) return false;
- if (type == 4) return true;
- if (type == 5) return true;
- if (type == 6) return true;
- if (type == 7) return true;
- if (type == 8) return true;
- if (type == 9) return true;
- if (type == 10) return false;
- if (type == 11) return false;
- if (type == 12) return false;
- return false;
- }
-
- static const char* type_to_dtype_string(int type)
- {
- if (type == 1) return "torch.float";
- if (type == 2) return "torch.double";
- if (type == 3) return "torch.half";
- if (type == 4) return "torch.int";
- if (type == 5) return "torch.long";
- if (type == 6) return "torch.short";
- if (type == 7) return "torch.int8";
- if (type == 8) return "torch.uint8";
- if (type == 9) return "torch.bool";
- if (type == 10) return "torch.complex64";
- if (type == 11) return "torch.complex128";
- if (type == 12) return "torch.complex32";
- return "null";
- }
-
- static bool string_is_positive_integer(const std::string& t)
- {
- for (size_t i = 0; i < t.size(); i++)
- {
- if (t[i] < '0' || t[i] > '9')
- return false;
- }
-
- return true;
- }
-
- static unsigned short float32_to_float16(float value)
- {
- // 1 : 8 : 23
- union
- {
- unsigned int u;
- float f;
- } tmp;
-
- tmp.f = value;
-
- // 1 : 8 : 23
- unsigned short sign = (tmp.u & 0x80000000) >> 31;
- unsigned short exponent = (tmp.u & 0x7F800000) >> 23;
- unsigned int significand = tmp.u & 0x7FFFFF;
-
- // NCNN_LOGE("%d %d %d", sign, exponent, significand);
-
- // 1 : 5 : 10
- unsigned short fp16;
- if (exponent == 0)
- {
- // zero or denormal, always underflow
- fp16 = (sign << 15) | (0x00 << 10) | 0x00;
- }
- else if (exponent == 0xFF)
- {
- // infinity or NaN
- fp16 = (sign << 15) | (0x1F << 10) | (significand ? 0x200 : 0x00);
- }
- else
- {
- // normalized
- short newexp = exponent + (-127 + 15);
- if (newexp >= 31)
- {
- // overflow, return infinity
- fp16 = (sign << 15) | (0x1F << 10) | 0x00;
- }
- else if (newexp <= 0)
- {
- // Some normal fp32 cannot be expressed as normal fp16
- fp16 = (sign << 15) | (0x00 << 10) | 0x00;
- }
- else
- {
- // normal fp16
- fp16 = (sign << 15) | (newexp << 10) | (significand >> 13);
- }
- }
-
- return fp16;
- }
-
- static size_t alignSize(size_t sz, int n)
- {
- return (sz + n - 1) & -n;
- }
-
- int save_ncnn(const Graph& g, const std::string& parampath, const std::string& binpath, const std::string& pypath, int fp16)
- {
- FILE* paramfp = fopen(parampath.c_str(), "wb");
- if (!paramfp)
- {
- fprintf(stderr, "fopen %s failed\n", parampath.c_str());
- return -1;
- }
-
- FILE* binfp = fopen(binpath.c_str(), "wb");
- if (!binfp)
- {
- fprintf(stderr, "fopen %s failed\n", binpath.c_str());
- fclose(paramfp);
- return -1;
- }
-
- // magic
- fprintf(paramfp, "7767517\n");
-
- // op count and oprand count
- fprintf(paramfp, "%d %d\n", (int)g.ops.size(), (int)g.operands.size());
-
- for (const Operator* op : g.ops)
- {
- fprintf(paramfp, "%-24s %-24s %d %d", op->type.c_str(), op->name.c_str(), (int)op->inputs.size(), (int)op->outputs.size());
-
- for (const Operand* oprand : op->inputs)
- {
- fprintf(paramfp, " %s", oprand->name.c_str());
- }
-
- for (const Operand* oprand : op->outputs)
- {
- fprintf(paramfp, " %s", oprand->name.c_str());
- }
-
- for (const auto& it : op->params)
- {
- const Parameter& param = it.second;
-
- if (!string_is_positive_integer(it.first))
- {
- fprintf(stderr, "ignore %s %s param %s=", op->type.c_str(), op->name.c_str(), it.first.c_str());
-
- if (param.type == 0)
- {
- fprintf(stderr, "None");
- }
- if (param.type == 1)
- {
- if (param.b)
- fprintf(stderr, "True");
- else
- fprintf(stderr, "False");
- }
- if (param.type == 2)
- {
- fprintf(stderr, "%d", param.i);
- }
- if (param.type == 3)
- {
- fprintf(stderr, "%e", param.f);
- }
- if (param.type == 4)
- {
- fprintf(stderr, "%s", param.s.c_str());
- }
- if (param.type == 5)
- {
- fprintf(stderr, "(");
- for (size_t i = 0; i < param.ai.size(); i++)
- {
- fprintf(stderr, "%d", param.ai[i]);
- if (i + 1 != param.ai.size())
- fprintf(stderr, ",");
- }
- fprintf(stderr, ")");
- }
- if (param.type == 6)
- {
- fprintf(stderr, "(");
- for (size_t i = 0; i < param.af.size(); i++)
- {
- fprintf(stderr, "%e", param.af[i]);
- if (i + 1 != param.af.size())
- fprintf(stderr, ",");
- }
- fprintf(stderr, ")");
- }
- if (param.type == 7)
- {
- fprintf(stderr, "(");
- for (size_t i = 0; i < param.as.size(); i++)
- {
- fprintf(stderr, "%s", param.as[i].c_str());
- if (i + 1 != param.as.size())
- fprintf(stderr, ",");
- }
- fprintf(stderr, ")");
- }
- fprintf(stderr, "\n");
-
- continue;
- }
-
- const int idkey = std::stoi(it.first);
- if (param.type == 2)
- {
- fprintf(paramfp, " %d=%d", idkey, param.i);
- }
- if (param.type == 3)
- {
- fprintf(paramfp, " %d=%e", idkey, param.f);
- }
- if (param.type == 5)
- {
- const int array_size = (int)param.ai.size();
- fprintf(paramfp, " %d=%d", -23300 - idkey, array_size);
- for (size_t i = 0; i < param.ai.size(); i++)
- {
- fprintf(paramfp, ",%d", param.ai[i]);
- }
- }
- if (param.type == 6)
- {
- const int array_size = (int)param.af.size();
- fprintf(paramfp, " %d=%d", -23300 - idkey, array_size);
- for (size_t i = 0; i < param.af.size(); i++)
- {
- fprintf(paramfp, ",%e", param.af[i]);
- }
- }
- }
-
- bool is_type_flag_fp32 = false;
- for (const auto& it : op->attrs)
- {
- // fprintf(paramfp, " @%s=", it.first.c_str());
-
- const Attribute& attr = it.second;
-
- if (fp16 && is_type_flag_fp32)
- {
- // fp32 -> fp16
- const float* p = (const float*)attr.data.data();
- int len = attr.data.size() / 4;
- std::vector<char> data_fp16(alignSize(len * 2, 4));
- unsigned short* p_fp16 = (unsigned short*)data_fp16.data();
- for (int i = 0; i < len; i++)
- {
- p_fp16[i] = float32_to_float16(p[i]);
- }
-
- // pad size to 4bytes
- if (len % 2 == 1)
- {
- // pad with fixed value for model hash consistency
- p_fp16[len] = 0x2283;
- }
-
- fwrite(data_fp16.data(), data_fp16.size(), 1, binfp);
-
- is_type_flag_fp32 = false;
- continue;
- }
-
- if (fp16 && attr.type == 0 && attr.data == std::vector<char> {0, 0, 0, 0})
- {
- // write fp16 flag
- unsigned int fp16_flag = 0x01306B47;
- fwrite((const char*)&fp16_flag, sizeof(fp16_flag), 1, binfp);
-
- is_type_flag_fp32 = true;
- continue;
- }
-
- fwrite(attr.data.data(), attr.data.size(), 1, binfp);
- }
-
- // if (op->inputnames.size() == op->inputs.size())
- // {
- // for (size_t i = 0; i < op->inputs.size(); i++)
- // {
- // const Operand* oprand = op->inputs[i];
- // fprintf(paramfp, " $%s=%s", op->inputnames[i].c_str(), oprand->name.c_str());
- // }
- // }
-
- // for (const Operand* oprand : op->outputs)
- // {
- // if (oprand->params.find("__batch_index") == oprand->params.end())
- // continue;
- //
- // const int batch_index = oprand->params.at("__batch_index").i;
- //
- // fprintf(paramfp, " #%s=%d", oprand->name.c_str(), batch_index);
- // }
-
- // for (const Operand* oprand : op->outputs)
- // {
- // if (oprand->shape.empty())
- // continue;
- //
- // fprintf(paramfp, " #%s=", oprand->name.c_str());
- //
- // fprintf(paramfp, "(");
- // for (int64_t i = 0; i < oprand->shape.size() - 1; i++)
- // {
- // fprintf(paramfp, "%d,", oprand->shape[i]);
- // }
- // if (oprand->shape.size() > 0)
- // fprintf(paramfp, "%d", oprand->shape[oprand->shape.size() - 1]);
- // fprintf(paramfp, ")");
- //
- // fprintf(paramfp, type_to_string(oprand->type));
- // }
-
- fprintf(paramfp, "\n");
- }
-
- fclose(paramfp);
- fclose(binfp);
-
- FILE* pyfp = fopen(pypath.c_str(), "wb");
- if (!pyfp)
- {
- fprintf(stderr, "fopen %s failed\n", pypath.c_str());
- return -1;
- }
-
- fprintf(pyfp, "import numpy as np\n");
- fprintf(pyfp, "import ncnn\n");
- fprintf(pyfp, "import torch\n");
-
- fprintf(pyfp, "\n");
-
- // test inference
- {
- fprintf(pyfp, "def test_inference():\n");
- fprintf(pyfp, " torch.manual_seed(0)\n");
-
- for (int input_index = 0;; input_index++)
- {
- std::string input_name = std::string("in") + std::to_string(input_index);
- const Operand* r = g.get_operand(input_name);
- if (!r)
- break;
-
- if (type_is_integer(r->type))
- {
- fprintf(pyfp, " %s = torch.randint(10, (", input_name.c_str());
- for (size_t i = 0; i < r->shape.size(); i++)
- {
- fprintf(pyfp, "%d", r->shape[i]);
- if (i + 1 != r->shape.size() || r->shape.size() == 1)
- fprintf(pyfp, ", ");
- }
- fprintf(pyfp, "), dtype=%s)\n", type_to_dtype_string(r->type));
- }
- else
- {
- fprintf(pyfp, " %s = torch.rand(", input_name.c_str());
- for (size_t i = 0; i < r->shape.size(); i++)
- {
- fprintf(pyfp, "%d, ", r->shape[i]);
- }
- fprintf(pyfp, "dtype=%s)\n", type_to_dtype_string(r->type));
- }
- }
-
- fprintf(pyfp, " out = []\n");
- fprintf(pyfp, "\n");
-
- fprintf(pyfp, " with ncnn.Net() as net:\n");
- fprintf(pyfp, " net.load_param(\"%s\")\n", parampath.c_str());
- fprintf(pyfp, " net.load_model(\"%s\")\n", binpath.c_str());
- fprintf(pyfp, "\n");
- fprintf(pyfp, " with net.create_extractor() as ex:\n");
-
- for (int input_index = 0;; input_index++)
- {
- std::string input_name = std::string("in") + std::to_string(input_index);
- const Operand* r = g.get_operand(input_name);
- if (!r)
- break;
-
- const int batch_index = r->params.at("__batch_index").i;
- if (batch_index != 233)
- {
- fprintf(pyfp, " ex.input(\"%s\", ncnn.Mat(%s.squeeze(%d).numpy()).clone())\n", input_name.c_str(), input_name.c_str(), batch_index);
- }
- else
- {
- fprintf(pyfp, " ex.input(\"%s\", ncnn.Mat(%s.numpy()).clone())\n", input_name.c_str(), input_name.c_str());
- }
- }
-
- fprintf(pyfp, "\n");
-
- for (int output_index = 0;; output_index++)
- {
- std::string output_name = std::string("out") + std::to_string(output_index);
- const Operand* r = g.get_operand(output_name);
- if (!r)
- break;
-
- fprintf(pyfp, " _, %s = ex.extract(\"%s\")\n", output_name.c_str(), output_name.c_str());
-
- const int batch_index = r->params.at("__batch_index").i;
- if (batch_index != 233)
- {
- fprintf(pyfp, " out.append(torch.from_numpy(np.array(%s)).unsqueeze(%d))\n", output_name.c_str(), batch_index);
- }
- else
- {
- fprintf(pyfp, " out.append(torch.from_numpy(np.array(%s)))\n", output_name.c_str());
- }
- }
-
- fprintf(pyfp, "\n");
-
- fprintf(pyfp, " if len(out) == 1:\n");
- fprintf(pyfp, " return out[0]\n");
- fprintf(pyfp, " else:\n");
- fprintf(pyfp, " return tuple(out)\n");
- }
-
- fclose(pyfp);
-
- return 0;
- }
-
- } // namespace pnnx
|