You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

main.cpp 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. // Tencent is pleased to support the open source community by making ncnn available.
  2. //
  3. // Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
  4. //
  5. // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
  6. // in compliance with the License. You may obtain a copy of the License at
  7. //
  8. // https://opensource.org/licenses/BSD-3-Clause
  9. //
  10. // Unless required by applicable law or agreed to in writing, software distributed
  11. // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
  12. // CONDITIONS OF ANY KIND, either express or implied. See the License for the
  13. // specific language governing permissions and limitations under the License.
  14. #include <stdio.h>
  15. #if _WIN32
  16. #include <windows.h>
  17. #else
  18. #include <dlfcn.h>
  19. #endif
  20. #include <string>
  21. #include <vector>
  22. #include <torch/script.h>
  23. #ifdef PNNX_TORCHVISION
  24. // register torchvision ops via including headers
  25. #include <torchvision/vision.h>
  26. #endif
  27. #include "ir.h"
  28. #include "pass_level0.h"
  29. #include "pass_level1.h"
  30. #include "pass_level2.h"
  31. #include "pass_level3.h"
  32. #include "pass_level4.h"
  33. #include "pass_level5.h"
  34. #include "pass_ncnn.h"
  // Return `path` with its file extension removed, e.g. "dir/model.pt" -> "dir/model".
  // Only a '.' inside the final path component (after the last '/' or '\\')
  // counts as the extension separator. A '.' that starts the file name (hidden
  // file) is kept. So "./model" and "dir.v1/model" are returned unchanged
  // instead of being truncated at a directory dot.
  static std::string get_basename(const std::string& path)
  {
      const size_t sep = path.find_last_of("/\\");
      const size_t name_start = (sep == std::string::npos) ? 0 : sep + 1;
      const size_t dot = path.find_last_of('.');

      // no dot at all, a dot in a directory name, or a leading dot of the file name
      if (dot == std::string::npos || dot <= name_start)
          return path;

      return path.substr(0, dot);
  }
  // Split a comma-separated option value into `list`, which is cleared first.
  // Empty fields are skipped, e.g. "a,,b," yields {"a", "b"}.
  // `s` is tokenized in place with strtok, so its contents are modified.
  static void parse_string_list(char* s, std::vector<std::string>& list)
  {
      list.clear();
      for (char* token = strtok(s, ","); token != NULL; token = strtok(NULL, ","))
          list.emplace_back(token);
  }
  // Write the entries of `list` to stderr joined by ',' (no trailing newline).
  static void print_string_list(const std::vector<std::string>& list)
  {
      const char* separator = "";
      for (const std::string& item : list)
      {
          fprintf(stderr, "%s%s", separator, item.c_str());
          separator = ",";
      }
  }
  // Parse an inputshape option value such as "[1,3,224,224]f32,[1,3,320,320]i64".
  // Each bracketed, comma-separated integer list becomes one entry of `shapes`.
  // Its type defaults to "f32". A following token that starts with 'f', 'i' or
  // 'u' (read up to the next ',', at most 31 chars) replaces the type of the
  // most recently parsed shape.
  // Both output vectors are cleared first and always end up the same length.
  // `s` is tokenized in place with strtok, so it is modified and this function
  // is not reentrant.
  static void parse_shape_list(char* s, std::vector<std::vector<int64_t> >& shapes, std::vector<std::string>& types)
  {
      shapes.clear();
      types.clear();

      char* pch = strtok(s, "[]");
      while (pch != NULL)
      {
          // assign user data type to the shape parsed just before this token
          if (!types.empty() && (pch[0] == 'f' || pch[0] == 'i' || pch[0] == 'u'))
          {
              char type[32];
              int nscan = sscanf(pch, "%31[^,]", type);
              if (nscan == 1)
              {
                  types.back() = std::string(type);
              }
          }

          // parse a,b,c -- read as long long so any int64_t dimension fits
          // (sscanf into an int is undefined behaviour for values beyond INT_MAX)
          long long v = 0;
          int nconsumed = 0;
          int nscan = sscanf(pch, "%lld%n", &v, &nconsumed);
          if (nscan == 1)
          {
              // ok we get shape
              std::vector<int64_t> shape;
              shape.push_back(static_cast<int64_t>(v));
              pch += nconsumed;
              while (sscanf(pch, ",%lld%n", &v, &nconsumed) == 1)
              {
                  shape.push_back(static_cast<int64_t>(v));
                  pch += nconsumed;
              }

              // shape end
              shapes.push_back(shape);
              types.push_back("f32");
          }

          pch = strtok(NULL, "[]");
      }
  }
  // Print shapes to stderr in the same syntax parse_shape_list accepts,
  // e.g. "[1,3,224,224]f32,[1,3,320,320]f32", with no trailing newline.
  // `types` must have at least as many entries as `shapes` (indexed in parallel).
  static void print_shape_list(const std::vector<std::vector<int64_t> >& shapes, const std::vector<std::string>& types)
  {
      for (size_t i = 0; i < shapes.size(); i++)
      {
          const std::vector<int64_t>& s = shapes[i];
          const std::string& t = types[i];
          fprintf(stderr, "[");
          for (size_t j = 0; j < s.size(); j++)
          {
              // int64_t is not `long` on every platform (e.g. 64-bit Windows),
              // so "%ld" was a format mismatch; cast and print as long long.
              fprintf(stderr, "%lld", static_cast<long long>(s[j]));
              if (j + 1 != s.size())
                  fprintf(stderr, ",");
          }
          fprintf(stderr, "]%s", t.c_str());
          if (i + 1 != shapes.size())
              fprintf(stderr, ",");
      }
  }
  118. static c10::ScalarType input_type_to_c10_ScalarType(const std::string& t)
  119. {
  120. if (t == "f32") return torch::kFloat32;
  121. if (t == "f16") return torch::kFloat16;
  122. if (t == "f64") return torch::kFloat64;
  123. if (t == "i32") return torch::kInt32;
  124. if (t == "i16") return torch::kInt16;
  125. if (t == "i64") return torch::kInt64;
  126. if (t == "i8") return torch::kInt8;
  127. if (t == "u8") return torch::kUInt8;
  128. fprintf(stderr, "unsupported type %s fallback to f32\n", t.c_str());
  129. return torch::kFloat32;
  130. }
  // Print command-line help to stderr: the accepted key=value options with
  // example values, followed by sample invocations.
  static void show_usage()
  {
      static const char* const kUsageLines[] = {
          "Usage: pnnx [model.pt] [(key=value)...]\n",
          " pnnxparam=model.pnnx.param\n",
          " pnnxbin=model.pnnx.bin\n",
          " pnnxpy=model_pnnx.py\n",
          " ncnnparam=model.ncnn.param\n",
          " ncnnbin=model.ncnn.bin\n",
          " ncnnpy=model_ncnn.py\n",
          " optlevel=2\n",
          " device=cpu/gpu\n",
          " inputshape=[1,3,224,224],...\n",
          " inputshape2=[1,3,320,320],...\n",
  #if _WIN32
          " customop=C:\\Users\\nihui\\AppData\\Local\\torch_extensions\\torch_extensions\\Cache\\fused\\fused.dll,...\n",
  #else
          " customop=/home/nihui/.cache/torch_extensions/fused/fused.so,...\n",
  #endif
          " moduleop=models.common.Focus,models.yolo.Detect,...\n",
          "Sample usage: pnnx mobilenet_v2.pt inputshape=[1,3,224,224]\n",
          " pnnx yolov5s.pt inputshape=[1,3,640,640]f32 inputshape2=[1,3,320,320]f32 device=gpu moduleop=models.common.Focus,models.yolo.Detect\n",
      };

      for (const char* line : kUsageLines)
          fputs(line, stderr);
  }
  153. int main(int argc, char** argv)
  154. {
  155. if (argc < 2)
  156. {
  157. show_usage();
  158. return -1;
  159. }
  160. for (int i = 1; i < argc; i++)
  161. {
  162. if (argv[i][0] == '-')
  163. {
  164. show_usage();
  165. return -1;
  166. }
  167. }
  168. std::string ptpath = std::string(argv[1]);
  169. std::string ptbase = get_basename(ptpath);
  170. std::string pnnxparampath = ptbase + ".pnnx.param";
  171. std::string pnnxbinpath = ptbase + ".pnnx.bin";
  172. std::string pnnxpypath = ptbase + "_pnnx.py";
  173. std::string ncnnparampath = ptbase + ".ncnn.param";
  174. std::string ncnnbinpath = ptbase + ".ncnn.bin";
  175. std::string ncnnpypath = ptbase + "_ncnn.py";
  176. int optlevel = 2;
  177. std::string device = "cpu";
  178. std::vector<std::vector<int64_t> > input_shapes;
  179. std::vector<std::string> input_types;
  180. std::vector<std::vector<int64_t> > input_shapes2;
  181. std::vector<std::string> input_types2;
  182. std::vector<std::string> customop_modules;
  183. std::vector<std::string> module_operators;
  184. for (int i = 2; i < argc; i++)
  185. {
  186. // key=value
  187. char* kv = argv[i];
  188. char* eqs = strchr(kv, '=');
  189. if (eqs == NULL)
  190. {
  191. fprintf(stderr, "unrecognized arg %s\n", kv);
  192. continue;
  193. }
  194. // split k v
  195. eqs[0] = '\0';
  196. const char* key = kv;
  197. char* value = eqs + 1;
  198. if (strcmp(key, "pnnxparam") == 0)
  199. pnnxparampath = std::string(value);
  200. if (strcmp(key, "pnnxbin") == 0)
  201. pnnxbinpath = std::string(value);
  202. if (strcmp(key, "pnnxpy") == 0)
  203. pnnxpypath = std::string(value);
  204. if (strcmp(key, "ncnnparam") == 0)
  205. ncnnparampath = std::string(value);
  206. if (strcmp(key, "ncnnbin") == 0)
  207. ncnnbinpath = std::string(value);
  208. if (strcmp(key, "ncnnpy") == 0)
  209. ncnnpypath = std::string(value);
  210. if (strcmp(key, "optlevel") == 0)
  211. optlevel = atoi(value);
  212. if (strcmp(key, "device") == 0)
  213. device = value;
  214. if (strcmp(key, "inputshape") == 0)
  215. parse_shape_list(value, input_shapes, input_types);
  216. if (strcmp(key, "inputshape2") == 0)
  217. parse_shape_list(value, input_shapes2, input_types2);
  218. if (strcmp(key, "customop") == 0)
  219. parse_string_list(value, customop_modules);
  220. if (strcmp(key, "moduleop") == 0)
  221. parse_string_list(value, module_operators);
  222. }
  223. // print options
  224. {
  225. fprintf(stderr, "pnnxparam = %s\n", pnnxparampath.c_str());
  226. fprintf(stderr, "pnnxbin = %s\n", pnnxbinpath.c_str());
  227. fprintf(stderr, "pnnxpy = %s\n", pnnxpypath.c_str());
  228. fprintf(stderr, "ncnnparam = %s\n", ncnnparampath.c_str());
  229. fprintf(stderr, "ncnnbin = %s\n", ncnnbinpath.c_str());
  230. fprintf(stderr, "ncnnpy = %s\n", ncnnpypath.c_str());
  231. fprintf(stderr, "optlevel = %d\n", optlevel);
  232. fprintf(stderr, "device = %s\n", device.c_str());
  233. fprintf(stderr, "inputshape = ");
  234. print_shape_list(input_shapes, input_types);
  235. fprintf(stderr, "\n");
  236. fprintf(stderr, "inputshape2 = ");
  237. print_shape_list(input_shapes2, input_types2);
  238. fprintf(stderr, "\n");
  239. fprintf(stderr, "customop = ");
  240. print_string_list(customop_modules);
  241. fprintf(stderr, "\n");
  242. fprintf(stderr, "moduleop = ");
  243. print_string_list(module_operators);
  244. fprintf(stderr, "\n");
  245. }
  246. for (auto m : customop_modules)
  247. {
  248. fprintf(stderr, "load custom module %s\n", m.c_str());
  249. #if _WIN32
  250. HMODULE handle = LoadLibraryExA(m.c_str(), NULL, LOAD_WITH_ALTERED_SEARCH_PATH);
  251. if (!handle)
  252. {
  253. fprintf(stderr, "LoadLibraryExA %s failed %s\n", m.c_str(), GetLastError());
  254. }
  255. #else
  256. void* handle = dlopen(m.c_str(), RTLD_LAZY);
  257. if (!handle)
  258. {
  259. fprintf(stderr, "dlopen %s failed %s\n", m.c_str(), dlerror());
  260. }
  261. #endif
  262. }
  263. std::vector<at::Tensor> input_tensors;
  264. for (size_t i = 0; i < input_shapes.size(); i++)
  265. {
  266. const std::vector<int64_t>& shape = input_shapes[i];
  267. const std::string& type = input_types[i];
  268. at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
  269. if (device == "gpu")
  270. t = t.cuda();
  271. input_tensors.push_back(t);
  272. }
  273. std::vector<at::Tensor> input_tensors2;
  274. for (size_t i = 0; i < input_shapes2.size(); i++)
  275. {
  276. const std::vector<int64_t>& shape = input_shapes2[i];
  277. const std::string& type = input_types2[i];
  278. at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
  279. if (device == "gpu")
  280. t = t.cuda();
  281. input_tensors2.push_back(t);
  282. }
  283. torch::jit::Module mod = torch::jit::load(ptpath);
  284. mod.eval();
  285. // mod.dump(true, false, false);
  286. // mod.dump(true, true, true);
  287. auto g = mod.get_method("forward").graph();
  288. // g->dump();
  289. fprintf(stderr, "############# pass_level0\n");
  290. std::map<std::string, pnnx::Attribute> foldable_constants;
  291. pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators, ptpath, foldable_constants);
  292. // g->dump();
  293. fprintf(stderr, "############# pass_level1\n");
  294. pnnx::Graph pnnx_graph;
  295. pnnx::pass_level1(mod, g, pnnx_graph);
  296. // g->dump();
  297. fprintf(stderr, "############# pass_level2\n");
  298. pnnx::pass_level2(pnnx_graph);
  299. pnnx_graph.save("debug.param", "debug.bin");
  300. if (optlevel >= 1)
  301. {
  302. fprintf(stderr, "############# pass_level3\n");
  303. pnnx::pass_level3(pnnx_graph);
  304. fprintf(stderr, "############# pass_level4\n");
  305. pnnx::pass_level4(pnnx_graph);
  306. }
  307. pnnx_graph.save("debug2.param", "debug2.bin");
  308. if (optlevel >= 2)
  309. {
  310. fprintf(stderr, "############# pass_level5\n");
  311. pnnx::pass_level5(pnnx_graph, foldable_constants);
  312. }
  313. pnnx_graph.save(pnnxparampath, pnnxbinpath);
  314. pnnx_graph.python(pnnxpypath, pnnxbinpath);
  315. // if (optlevel >= 2)
  316. {
  317. fprintf(stderr, "############# pass_ncnn\n");
  318. pnnx::pass_ncnn(pnnx_graph);
  319. pnnx_graph.ncnn(ncnnparampath, ncnnbinpath, ncnnpypath);
  320. }
  321. // pnnx::Graph pnnx_graph2;
  322. // pnnx_graph2.load("pnnx.param", "pnnx.bin");
  323. // pnnx_graph2.save("pnnx2.param", "pnnx2.bin");
  324. return 0;
  325. }