You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

load_torchscript.cpp 24 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734
  1. // Copyright 2024 Tencent
  2. // SPDX-License-Identifier: BSD-3-Clause
  3. #include "load_torchscript.h"
  4. #if _WIN32
  5. #include <windows.h>
  6. #else
  7. #include <dlfcn.h>
  8. #endif
  9. #include <torch/script.h>
  10. #include <torch/csrc/api/include/torch/version.h>
  11. #include <torch/csrc/jit/serialization/import_read.h>
  12. #ifdef PNNX_TORCHVISION
  13. namespace vision {
  14. int64_t cuda_version();
  15. } // namespace vision
  16. #endif
  17. #include "pass_level0.h"
  18. #include "pass_level1.h"
  19. #include "pass_level1/fuse_module_pass.h"
  20. namespace pnnx {
  21. static int get_at_tensor_type(const at::ScalarType& st)
  22. {
  23. if (st == c10::ScalarType::Float) return 1;
  24. if (st == c10::ScalarType::Double) return 2;
  25. if (st == c10::ScalarType::Half) return 3;
  26. if (st == c10::ScalarType::Int) return 4;
  27. if (st == c10::ScalarType::QInt32) return 4;
  28. if (st == c10::ScalarType::Long) return 5;
  29. if (st == c10::ScalarType::Short) return 6;
  30. if (st == c10::ScalarType::Char) return 7;
  31. if (st == c10::ScalarType::QInt8) return 7;
  32. if (st == c10::ScalarType::Byte) return 8;
  33. if (st == c10::ScalarType::QUInt8) return 8;
  34. if (st == c10::ScalarType::Bool) return 9;
  35. if (st == c10::ScalarType::ComplexFloat) return 10;
  36. if (st == c10::ScalarType::ComplexDouble) return 11;
  37. if (st == c10::ScalarType::ComplexHalf) return 12;
  38. if (st == c10::ScalarType::BFloat16) return 13;
  39. return 0; // unknown type
  40. }
  // Size in bytes of one element for an integer type code (1..13).
  // Any other code, including 0 (null), yields 0.
  static size_t type_to_elemsize(int type)
  {
      // index = type code, value = bytes per element
      static const size_t elemsize_by_type[] = {0, 4, 8, 2, 4, 8, 2, 1, 1, 1, 8, 16, 4, 2};
      const int type_count = (int)(sizeof(elemsize_by_type) / sizeof(elemsize_by_type[0]));

      if (type < 0 || type >= type_count)
          return 0; // null

      return elemsize_by_type[type];
  }
  58. Parameter::Parameter(const torch::jit::Node* value_node)
  59. {
  60. type = 0;
  61. if (value_node->kind() == c10::prim::Constant)
  62. {
  63. if (value_node->output()->type()->kind() == c10::TypeKind::NoneType)
  64. {
  65. type = 0;
  66. return;
  67. }
  68. if (!value_node->hasAttribute(torch::jit::attr::value))
  69. {
  70. fprintf(stderr, "no attribute value\n");
  71. value_node->dump();
  72. return;
  73. }
  74. switch (value_node->output()->type()->kind())
  75. {
  76. case c10::TypeKind::NoneType:
  77. {
  78. type = 0;
  79. break;
  80. }
  81. case c10::TypeKind::BoolType:
  82. {
  83. type = 1;
  84. b = value_node->i(torch::jit::attr::value);
  85. break;
  86. }
  87. case c10::TypeKind::IntType:
  88. {
  89. type = 2;
  90. int64_t i64 = value_node->i(torch::jit::attr::value);
  91. if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
  92. if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
  93. if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
  94. if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
  95. i = (int)i64;
  96. break;
  97. }
  98. case c10::TypeKind::FloatType:
  99. {
  100. type = 3;
  101. f = (float)value_node->f(torch::jit::attr::value);
  102. break;
  103. }
  104. case c10::TypeKind::StringType:
  105. {
  106. type = 4;
  107. s = value_node->s(torch::jit::attr::value);
  108. break;
  109. }
  110. case c10::TypeKind::DeviceObjType:
  111. {
  112. type = 4;
  113. s = value_node->s(torch::jit::attr::value);
  114. break;
  115. }
  116. #if TORCH_VERSION_MAJOR >= 2 || (TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 9)
  117. case c10::TypeKind::ComplexType:
  118. {
  119. type = 10;
  120. c = std::complex<float>(value_node->c(torch::jit::attr::value));
  121. break;
  122. }
  123. #endif
  124. case c10::TypeKind::TensorType:
  125. {
  126. at::Tensor t = value_node->t(torch::jit::attr::value);
  127. if (t.dim() == 0 && t.numel() == 1)
  128. {
  129. if (t.scalar_type() == c10::ScalarType::Long)
  130. {
  131. type = 2;
  132. int64_t i64 = t.item<int64_t>();
  133. if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
  134. if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
  135. if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
  136. if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
  137. i = (int)i64;
  138. }
  139. else if (t.scalar_type() == c10::ScalarType::Int)
  140. {
  141. type = 2;
  142. i = t.item<int>();
  143. }
  144. else if (t.scalar_type() == c10::ScalarType::Double)
  145. {
  146. type = 3;
  147. f = (float)t.item<double>();
  148. }
  149. else if (t.scalar_type() == c10::ScalarType::Float)
  150. {
  151. type = 3;
  152. f = t.item<float>();
  153. }
  154. else if (t.scalar_type() == c10::ScalarType::ComplexDouble)
  155. {
  156. type = 10;
  157. c = std::complex<float>(t.item<c10::complex<double> >());
  158. }
  159. else if (t.scalar_type() == c10::ScalarType::ComplexFloat)
  160. {
  161. type = 10;
  162. c = std::complex<float>(t.item<c10::complex<float> >());
  163. }
  164. else
  165. {
  166. fprintf(stderr, "unknown Parameter value kind %s of TensorType, t.dim = 0\n", value_node->kind().toDisplayString());
  167. }
  168. }
  169. else
  170. {
  171. // constant tensor will become pnnx attribute node later
  172. type = 8;
  173. }
  174. break;
  175. }
  176. case c10::TypeKind::ListType:
  177. {
  178. switch (value_node->output()->type()->containedTypes()[0]->kind())
  179. {
  180. case c10::TypeKind::IntType:
  181. {
  182. type = 5;
  183. std::vector<int64_t> i64s = value_node->ival(torch::jit::attr::value).toIntVector();
  184. for (auto i64 : i64s)
  185. {
  186. if (i64 == std::numeric_limits<int64_t>::max()) i64 = INT_MAX;
  187. if (i64 == std::numeric_limits<int64_t>::max() - 1) i64 = INT_MAX - 1;
  188. if (i64 == std::numeric_limits<int64_t>::min()) i64 = INT_MIN;
  189. if (i64 == std::numeric_limits<int64_t>::min() + 1) i64 = INT_MIN + 1;
  190. ai.push_back(i64);
  191. }
  192. break;
  193. }
  194. case c10::TypeKind::FloatType:
  195. {
  196. type = 6;
  197. std::vector<double> fs = value_node->ival(torch::jit::attr::value).toDoubleVector();
  198. for (auto f : fs)
  199. {
  200. af.push_back((float)f);
  201. }
  202. break;
  203. }
  204. default:
  205. {
  206. fprintf(stderr, "unknown Parameter value list element kind %s\n", c10::typeKindToString(value_node->output()->type()->containedTypes()[0]->kind()));
  207. break;
  208. }
  209. }
  210. break;
  211. }
  212. default:
  213. {
  214. fprintf(stderr, "unknown Parameter value kind %s\n", c10::typeKindToString(value_node->output()->type()->kind()));
  215. break;
  216. }
  217. }
  218. }
  219. else if (value_node->kind() == c10::prim::ListConstruct)
  220. {
  221. switch (value_node->output()->type()->cast<c10::ListType>()->getElementType()->kind())
  222. {
  223. case c10::TypeKind::IntType:
  224. {
  225. type = 5;
  226. for (const auto& x : value_node->inputs())
  227. {
  228. if (!x->node()->hasAttribute(torch::jit::attr::value))
  229. {
  230. fprintf(stderr, "no attribute value in int list\n");
  231. ai.push_back(0);
  232. continue;
  233. }
  234. ai.push_back((int)x->node()->i(torch::jit::attr::value));
  235. }
  236. break;
  237. }
  238. case c10::TypeKind::FloatType:
  239. {
  240. type = 6;
  241. for (const auto& x : value_node->inputs())
  242. {
  243. if (!x->node()->hasAttribute(torch::jit::attr::value))
  244. {
  245. fprintf(stderr, "no attribute value in float list\n");
  246. af.push_back(0.f);
  247. continue;
  248. }
  249. af.push_back((float)x->node()->f(torch::jit::attr::value));
  250. }
  251. break;
  252. }
  253. case c10::TypeKind::StringType:
  254. {
  255. type = 7;
  256. for (const auto& x : value_node->inputs())
  257. {
  258. if (!x->node()->hasAttribute(torch::jit::attr::value))
  259. {
  260. fprintf(stderr, "no attribute value in string list\n");
  261. as.push_back("");
  262. continue;
  263. }
  264. as.push_back(x->node()->s(torch::jit::attr::value));
  265. }
  266. break;
  267. }
  268. #if TORCH_VERSION_MAJOR >= 2 || (TORCH_VERSION_MAJOR >= 1 && TORCH_VERSION_MINOR >= 9)
  269. case c10::TypeKind::ComplexType:
  270. {
  271. type = 11;
  272. for (const auto& x : value_node->inputs())
  273. {
  274. if (!x->node()->hasAttribute(torch::jit::attr::value))
  275. {
  276. fprintf(stderr, "no attribute value in complex list\n");
  277. ac.push_back(std::complex<float>(0.f, 0.f));
  278. continue;
  279. }
  280. ac.push_back(std::complex<float>(x->node()->c(torch::jit::attr::value)));
  281. }
  282. break;
  283. }
  284. #endif
  285. default:
  286. {
  287. fprintf(stderr, "unknown Parameter value list element kind %s\n", c10::typeKindToString(value_node->output()->type()->cast<c10::ListType>()->getElementType()->kind()));
  288. break;
  289. }
  290. }
  291. }
  292. else
  293. {
  294. fprintf(stderr, "unknown Parameter value_node kind %s\n", value_node->kind().toDisplayString());
  295. }
  296. }
  297. Parameter::Parameter(const torch::jit::Value* value)
  298. : Parameter(value->node())
  299. {
  300. }
  301. Attribute::Attribute(const at::Tensor& t)
  302. {
  303. type = get_at_tensor_type(t.scalar_type());
  304. const int ndim = (int)t.dim();
  305. if (ndim == 0)
  306. {
  307. shape = {1};
  308. data.resize(type_to_elemsize(type));
  309. if (t.scalar_type() == c10::ScalarType::Long)
  310. {
  311. int64_t i = t.item<int64_t>();
  312. memcpy((void*)data.data(), (const void*)&i, data.size());
  313. }
  314. else if (t.scalar_type() == c10::ScalarType::Int)
  315. {
  316. int i = t.item<int>();
  317. memcpy((void*)data.data(), (const void*)&i, data.size());
  318. }
  319. else if (t.scalar_type() == c10::ScalarType::Double)
  320. {
  321. double f = t.item<double>();
  322. memcpy((void*)data.data(), (const void*)&f, data.size());
  323. }
  324. else if (t.scalar_type() == c10::ScalarType::Float)
  325. {
  326. float f = t.item<float>();
  327. memcpy((void*)data.data(), (const void*)&f, data.size());
  328. }
  329. else
  330. {
  331. fprintf(stderr, "unknown Attribute tensor scalar type %d\n", type);
  332. }
  333. return;
  334. }
  335. shape.resize(ndim);
  336. for (int i = 0; i < ndim; i++)
  337. shape[i] = t.size(i);
  338. if (shape.size() > 0)
  339. {
  340. data.resize(elemcount() * type_to_elemsize(type));
  341. memcpy((void*)data.data(), (const void*)t.cpu().contiguous().data_ptr(), data.size());
  342. }
  343. }
  344. Attribute::Attribute(const TorchTensorProxy& t)
  345. : Attribute(t.t())
  346. {
  347. }
  348. Operand* Graph::new_operand(const torch::jit::Value* v)
  349. {
  350. // Operand* r = new Operand;
  351. // r->name = v->debugName();
  352. Operand* r = new_operand(v->debugName());
  353. r->type = -1;
  354. auto pt = v->type()->cast<c10::TensorType>();
  355. if (pt)
  356. {
  357. if (pt->scalarType().has_value() && pt->dim().has_value())
  358. {
  359. r->type = get_at_tensor_type(pt->scalarType().value());
  360. const int ndim = (int)pt->dim().value();
  361. r->shape.resize(ndim);
  362. for (int i = 0; i < ndim; i++)
  363. {
  364. if (pt->sizes()[i].has_value())
  365. r->shape[i] = (int)pt->sizes()[i].value();
  366. else
  367. r->shape[i] = -1;
  368. }
  369. }
  370. }
  371. // operands.push_back(r);
  372. return r;
  373. }
  374. static c10::ScalarType input_type_to_c10_ScalarType(const std::string& t)
  375. {
  376. if (t == "c64") return torch::kComplexFloat;
  377. if (t == "c32") return torch::kComplexHalf;
  378. if (t == "c128") return torch::kComplexDouble;
  379. if (t == "bf16") return torch::kBFloat16;
  380. if (t == "f32") return torch::kFloat32;
  381. if (t == "f16") return torch::kFloat16;
  382. if (t == "f64") return torch::kFloat64;
  383. if (t == "i32") return torch::kInt32;
  384. if (t == "i16") return torch::kInt16;
  385. if (t == "i64") return torch::kInt64;
  386. if (t == "i8") return torch::kInt8;
  387. if (t == "u8") return torch::kUInt8;
  388. fprintf(stderr, "unsupported type %s fallback to f32\n", t.c_str());
  389. return torch::kFloat32;
  390. }
  391. static const char* get_at_tensor_type_str(const at::ScalarType& st)
  392. {
  393. if (st == c10::ScalarType::Float) return "f32";
  394. if (st == c10::ScalarType::Double) return "f64";
  395. if (st == c10::ScalarType::Half) return "f16";
  396. if (st == c10::ScalarType::Int) return "i32";
  397. if (st == c10::ScalarType::Long) return "i64";
  398. if (st == c10::ScalarType::Short) return "i16";
  399. if (st == c10::ScalarType::Char) return "i8";
  400. if (st == c10::ScalarType::Byte) return "u8";
  401. if (st == c10::ScalarType::ComplexFloat) return "c64";
  402. if (st == c10::ScalarType::ComplexDouble) return "c128";
  403. if (st == c10::ScalarType::ComplexHalf) return "c32";
  404. if (st == c10::ScalarType::BFloat16) return "bf16";
  405. // unknown
  406. fprintf(stderr, "unsupported tensor elem data type %d\n", (int)st);
  407. return "";
  408. }
  // Print a comma separated list of shapes with their type names to stderr,
  // formatted as "[d0,d1,...]type,[d0,...]type".
  // No trailing newline is written. A shape that has no matching entry in `types`
  // is printed without a type name.
  static void print_shape_list(const std::vector<std::vector<int64_t> >& shapes, const std::vector<std::string>& types)
  {
      for (size_t i = 0; i < shapes.size(); i++)
      {
          if (i != 0)
              fprintf(stderr, ",");

          const std::vector<int64_t>& s = shapes[i];

          fprintf(stderr, "[");
          for (size_t j = 0; j < s.size(); j++)
          {
              if (j != 0)
                  fprintf(stderr, ",");

              // int64_t is not `long` on every platform, so print it as long long
              fprintf(stderr, "%lld", (long long)s[j]);
          }
          fprintf(stderr, "]");

          if (i < types.size())
              fprintf(stderr, "%s", types[i].c_str());
      }
  }
  428. static void append_input(std::vector<std::vector<int64_t> >& input_shapes, std::vector<std::string>& input_types, const torch::jit::IValue& v)
  429. {
  430. if (v.isTensor())
  431. {
  432. const auto& tensor = v.toTensor();
  433. input_shapes.push_back(tensor.sizes().vec());
  434. input_types.push_back(get_at_tensor_type_str(tensor.scalar_type()));
  435. }
  436. else if (v.isList())
  437. {
  438. for (const auto& v2 : v.toList())
  439. append_input(input_shapes, input_types, v2);
  440. }
  441. else if (v.isTuple())
  442. {
  443. for (const auto& v2 : v.toTuple()->elements())
  444. append_input(input_shapes, input_types, v2);
  445. }
  446. else if (v.isGenericDict())
  447. {
  448. for (const auto& kv2 : v.toGenericDict())
  449. append_input(input_shapes, input_types, kv2.value());
  450. }
  451. else
  452. {
  453. fprintf(stderr, "unsupported traced input type %s\n", v.tagKind().c_str());
  454. }
  455. }
  456. static void get_traced_input_shape(const std::string& ptpath, std::vector<std::vector<int64_t> >& input_shapes, std::vector<std::string>& input_types)
  457. {
  458. try
  459. {
  460. // read traced_inputs.pkl
  461. caffe2::serialize::PyTorchStreamReader reader(ptpath);
  462. auto v = torch::jit::readArchiveAndTensors("traced_inputs", "", "traced_inputs/", c10::nullopt, c10::nullopt, c10::nullopt, reader);
  463. if (!v.isGenericDict())
  464. return;
  465. for (const auto& entry : v.toGenericDict())
  466. {
  467. if (entry.key() != "forward")
  468. continue;
  469. append_input(input_shapes, input_types, entry.value());
  470. break;
  471. }
  472. }
  473. catch (...)
  474. {
  475. // no traced_inputs.pkl pass
  476. }
  477. }
  // Print one shape followed by its type name to stderr as "[d0,d1,...]type".
  static void print_shape_and_type(const std::vector<int64_t>& shape, const std::string& type)
  {
      fprintf(stderr, "[");
      for (size_t j = 0; j < shape.size(); j++)
      {
          // int64_t is not `long` on every platform, so print it as long long
          fprintf(stderr, "%lld", (long long)shape[j]);
          if (j + 1 != shape.size())
              fprintf(stderr, ",");
      }
      fprintf(stderr, "]%s", type.c_str());
  }

  // Compare the user supplied input shapes/types against the traced ones.
  //
  // Returns true when both lists have the same number of entries and every entry has
  // the same shape and the same type name. On the first difference a description of
  // the expected and the actual entry is written to stderr and false is returned.
  // A type list shorter than its shape list is treated as having empty type names
  // for the missing entries.
  static bool check_input_shape(const std::vector<std::vector<int64_t> >& traced_input_shapes, const std::vector<std::string>& traced_input_types, const std::vector<std::vector<int64_t> >& input_shapes, const std::vector<std::string>& input_types)
  {
      if (input_shapes.size() != traced_input_shapes.size())
      {
          fprintf(stderr, "input_shape expect %d tensors but got %d\n", (int)traced_input_shapes.size(), (int)input_shapes.size());
          return false;
      }

      const std::string no_type;

      for (size_t i = 0; i < traced_input_shapes.size(); i++)
      {
          const std::string& traced_type = i < traced_input_types.size() ? traced_input_types[i] : no_type;
          const std::string& input_type = i < input_types.size() ? input_types[i] : no_type;

          if (input_shapes[i] == traced_input_shapes[i] && input_type == traced_type)
              continue;

          fprintf(stderr, "input_shapes[%d] expect ", (int)i);
          print_shape_and_type(traced_input_shapes[i], traced_type);
          fprintf(stderr, " but got ");
          print_shape_and_type(input_shapes[i], input_type);
          fprintf(stderr, "\n");
          return false;
      }

      return true;
  }
  532. int load_torchscript(const std::string& ptpath, Graph& pnnx_graph,
  533. const std::string& device,
  534. const std::vector<std::vector<int64_t> >& input_shapes,
  535. const std::vector<std::string>& input_types,
  536. const std::vector<std::vector<int64_t> >& input_shapes2,
  537. const std::vector<std::string>& input_types2,
  538. const std::vector<std::string>& customop_modules,
  539. const std::vector<std::string>& module_operators,
  540. const std::string& foldable_constants_zippath,
  541. std::set<std::string>& foldable_constants)
  542. {
  543. // get input shape from traced torchscript
  544. std::vector<std::vector<int64_t> > traced_input_shapes;
  545. std::vector<std::string> traced_input_types;
  546. get_traced_input_shape(ptpath, traced_input_shapes, traced_input_types);
  547. if (!traced_input_shapes.empty())
  548. {
  549. fprintf(stderr, "get inputshape from traced inputs\n");
  550. fprintf(stderr, "inputshape = ");
  551. print_shape_list(traced_input_shapes, traced_input_types);
  552. fprintf(stderr, "\n");
  553. if (!input_shapes.empty())
  554. {
  555. // input shape sanity check
  556. if (!check_input_shape(traced_input_shapes, traced_input_types, input_shapes, input_types))
  557. {
  558. return -1;
  559. }
  560. }
  561. // traced torchscript always has static input shapes
  562. // if (!input_shapes2.empty() && !check_input_shape(ptpath, input_shapes2, input_types2))
  563. // {
  564. // return -1;
  565. // }
  566. }
  567. else
  568. {
  569. traced_input_shapes = input_shapes;
  570. traced_input_types = input_types;
  571. }
  572. #ifdef PNNX_TORCHVISION
  573. // call some vision api to register vision ops :P
  574. (void)vision::cuda_version();
  575. #endif
  576. for (auto m : customop_modules)
  577. {
  578. fprintf(stderr, "load custom module %s\n", m.c_str());
  579. #if _WIN32
  580. HMODULE handle = LoadLibraryExA(m.c_str(), NULL, LOAD_WITH_ALTERED_SEARCH_PATH);
  581. if (!handle)
  582. {
  583. fprintf(stderr, "LoadLibraryExA %s failed %d\n", m.c_str(), GetLastError());
  584. }
  585. #else
  586. void* handle = dlopen(m.c_str(), RTLD_LAZY);
  587. if (!handle)
  588. {
  589. fprintf(stderr, "dlopen %s failed %s\n", m.c_str(), dlerror());
  590. }
  591. #endif
  592. }
  593. std::vector<at::Tensor> input_tensors;
  594. for (size_t i = 0; i < traced_input_shapes.size(); i++)
  595. {
  596. const std::vector<int64_t>& shape = traced_input_shapes[i];
  597. const std::string& type = traced_input_types[i];
  598. at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
  599. if (device == "gpu")
  600. t = t.cuda();
  601. input_tensors.push_back(t);
  602. }
  603. std::vector<at::Tensor> input_tensors2;
  604. for (size_t i = 0; i < input_shapes2.size(); i++)
  605. {
  606. const std::vector<int64_t>& shape = input_shapes2[i];
  607. const std::string& type = input_types2[i];
  608. at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
  609. if (device == "gpu")
  610. t = t.cuda();
  611. input_tensors2.push_back(t);
  612. }
  613. torch::jit::Module mod;
  614. try
  615. {
  616. mod = torch::jit::load(ptpath, (device == "gpu") ? c10::kCUDA : c10::kCPU);
  617. }
  618. catch (const c10::Error& e)
  619. {
  620. fprintf(stderr, "Load torchscript failed: %s\n", e.what());
  621. fprintf(stderr, "Please export model to torchscript as follows\n");
  622. fprintf(stderr, "------------------------------------------\n");
  623. fprintf(stderr, "import torch\n");
  624. fprintf(stderr, "import torchvision.models as models\n\n");
  625. fprintf(stderr, "net = models.resnet18(pretrained=True)\n");
  626. fprintf(stderr, "net = net.eval()\n\n");
  627. fprintf(stderr, "x = torch.rand(1, 3, 224, 224)\n");
  628. fprintf(stderr, "mod = torch.jit.trace(net, x)\n");
  629. fprintf(stderr, "mod.save(\"resnet18.pt\")\n");
  630. fprintf(stderr, "------------------------------------------\n");
  631. return -1;
  632. }
  633. mod.eval();
  634. // mod.dump(true, false, false);
  635. // mod.dump(true, true, true);
  636. auto method = mod.find_method("forward");
  637. if (!method)
  638. {
  639. auto methods = mod.get_methods();
  640. if (methods.empty())
  641. {
  642. fprintf(stderr, "No method in torchscript\n");
  643. return -1;
  644. }
  645. method = methods[0];
  646. fprintf(stderr, "Use method %s as the entrypoint instead of forward\n", method->name().c_str());
  647. }
  648. auto g = method->graph();
  649. // g->dump();
  650. fprintf(stderr, "############# pass_level0\n");
  651. pnnx::pass_level0(mod, g, input_tensors, input_tensors2, module_operators, ptpath, device, foldable_constants, foldable_constants_zippath);
  652. // g->dump();
  653. fprintf(stderr, "############# pass_level1\n");
  654. pnnx::pass_level1(mod, g, module_operators, pnnx_graph);
  655. return 0;
  656. }
  657. } // namespace pnnx