
tensorrt_opr.cpp

/**
 * \file src/tensorrt/impl/tensorrt_opr.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/tensorrt/tensorrt_opr.h"
#include "megbrain/tensorrt/tensorrt_engine_cache.h"

#include "megbrain/common.h"
#include "megbrain/plugin/profiler.h"
#include "megbrain/version_symbol.h"
#include "megbrain/utils/timer.h"

#include <cinttypes>

#if MGB_ENABLE_TENSOR_RT

using namespace mgb;
using namespace opr;
using TensorRTManager = intl::TensorRTManager;

namespace {

#if MGB_ENABLE_JSON
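
// collects the per-layer execution times reported by TensorRT when profiling
// is enabled, and exposes the records as plain text or as a JSON array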
class TensorRTProfiler : public nvinfer1::IProfiler {
public:
    typedef std::pair<std::string, float> Record;
    std::vector<Record> profile;

    void reportLayerTime(const char* layerName, float ms) override;
    void print_layer_times();
    std::shared_ptr<json::Value> to_json();
};

void TensorRTProfiler::reportLayerTime(const char* layerName, float ms) {
    profile.push_back(std::make_pair(layerName, ms));
}

void TensorRTProfiler::print_layer_times() {
    float total_time = 0;
    for (size_t i = 0; i < profile.size(); ++i) {
        printf("%s %4.3fms\n", profile[i].first.c_str(), profile[i].second);
        total_time += profile[i].second;
    }
    printf("Total time: %4.3fms\n", total_time);
}

std::shared_ptr<json::Value> TensorRTProfiler::to_json() {
    using namespace json;
    auto prof_arr = Array::make();
    for (auto&& rec : profile) {
        auto&& item = Array::make();
        item->add(String::make(rec.first));
        item->add(Number::make(rec.second));
        prof_arr->add(item);
    }
    return prof_arr;
}

#endif  // MGB_ENABLE_JSON

}  // anonymous namespace

/* ========================== Logger ========================== */

void TensorRTOpr::Logger::log(nvinfer1::ILogger::Severity severity,
                              const char* msg) {
    switch (severity) {
        case Severity::kINTERNAL_ERROR:
            mgb_log("TRT_INTERNAL_ERROR: %s", msg);
            return;
        case Severity::kERROR:
            mgb_log("TRT_ERROR: %s", msg);
            return;
        case Severity::kWARNING:
            mgb_log("TRT_WARNING: %s", msg);
            return;
        case Severity::kINFO:
            mgb_log_debug("TRT_INFO: %s", msg);
            return;
#if NV_TENSOR_RT_VERSION >= 6001
        case Severity::kVERBOSE:
            mgb_log_debug("TRT_VERBOSE: %s", msg);
            return;
#endif
        default:
            mgb_log_debug("TRT_UNKNOWN: %s", msg);
            return;
    }
}
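
// logs the TensorRT version loaded at runtime and asserts that it is not older
// than the version MegBrain was compiled against; warns on a minor mismatch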
TensorRTOpr::Logger::Logger() {
    int expect = NV_TENSORRT_MAJOR * 1000 + NV_TENSORRT_MINOR * 100 +
                 NV_TENSORRT_PATCH,
        got = getInferLibVersion();
    mgb_log("loaded TensorRT version: %d", got);
    mgb_assert(expect <= got,
               "TensorRT library is older than mgb compiled version: got=%d "
               "compiled_with=%d",
               got, expect);
    if (expect != got) {
        mgb_log_warn(
                "MegBrain is compiled with TensorRT %d but get %d at runtime",
                expect, got);
    }
}

TensorRTOpr::Logger& TensorRTOpr::Logger::instance() {
    static Logger logger;
    return logger;
}

/* ========================== GpuAllocator ========================== */

TensorRTOpr::GpuAllocator::GpuAllocator(CompNode cn) : m_cn{cn} {
    mgb_assert(cn.device_type() == CompNode::DeviceType::CUDA,
               "can not use GPU allocator on comp node %s",
               cn.to_string().c_str());
}

TensorRTOpr::GpuAllocator::~GpuAllocator() noexcept {
    MGB_LOCK_GUARD(m_ptr2size_mtx);
    if (!m_ptr2size.empty()) {
        std::string msg{"there are unreleased TRT mem buffers:\n"};
        for (auto&& i : m_ptr2size) {
            msg.append(ssprintf(" %p: %zu\n", i.first, i.second));
        }
        mgb_log_error("%sabort now", msg.c_str());
        mgb_trap();
    }
}
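
// serves TensorRT device-memory requests from the comp node allocator and
// records every live allocation so that leaks can be detected in the
// destructor above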
void* TensorRTOpr::GpuAllocator::allocate(uint64_t size, uint64_t alignment,
                                          uint32_t flags) {
    static bool enable_log = getenv("MGB_LOG_TRT_MEM_ALLOC");
    mgb_assert(!flags && !(alignment & (alignment - 1)),
               "flags=%u alignment=%" PRIu64, flags, alignment);
    auto ret = m_cn.alloc_device(size);
    mgb_assert(!(reinterpret_cast<uintptr_t>(ret) & (alignment - 1)),
               "ptr=%p alignment=%" PRIu64, ret, alignment);
    if (enable_log) {
        mgb_log("trt mem alloc on %s: size=%" PRIu64 " align=%" PRIu64
                " ptr=%p",
                m_cn.to_string().c_str(), size, alignment, ret);
    }
    {
        MGB_LOCK_GUARD(m_ptr2size_mtx);
        m_ptr2size[ret] = size;
    }
    return ret;
}

void TensorRTOpr::GpuAllocator::free(void* memory) {
    {
        // guard the bookkeeping map, mirroring allocate()
        MGB_LOCK_GUARD(m_ptr2size_mtx);
        auto iter = m_ptr2size.find(memory);
        mgb_assert(iter != m_ptr2size.end(), "ptr %p not found", memory);
        m_ptr2size.erase(iter);
    }
    m_cn.free_device(memory);
}

/* ========================== TensorRTManager ========================== */
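
// binds input/output device buffers to the engine's binding slots, (re)creates
// the execution context and its device workspace when needed, then runs the
// engine: asynchronously on the comp node's CUDA stream in the normal path, or
// synchronously when a profiler is attached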
void TensorRTManager::exec(cg::SingleCNOperatorNodeBase* opr,
                           CompNode comp_node_check,
                           nvinfer1::ICudaEngine* engine,
                           size_t batch) {
    auto comp_node = opr->comp_node();
    // ICudaEngine is bound to the currently active device
    comp_node.activate();
    if (comp_node_check.valid()) {
        mgb_assert(comp_node_check == comp_node,
                   "gpu allocator is on %s, but execution is on %s",
                   comp_node_check.to_string().c_str(),
                   comp_node.to_string().c_str());
    }
#if MGB_ENABLE_JSON
    auto pf_holder_pair =
            opr->owner_graph()
                    ->options()
                    .user_data.get_user_data<opr_profile::OprProfileHolder>();
    if (m_has_profiler && !pf_holder_pair.second) {
        m_context.reset();
        m_has_profiler = false;
    }
#endif
    auto workspace_ptr = opr->output().back()->dev_tensor().raw_ptr();
    bool should_reinit_device_memory =
            !m_context || m_device_workspace_memory_ptr != workspace_ptr;
    if (!m_context) {
        m_context = {engine->createExecutionContextWithoutDeviceMemory(), {}};
        m_has_profiler = false;
    }
    m_trt_iobuf.resize(opr->input().size() + opr->output().size() - 1);
    bool is_trt_opr = false;
    if (opr->same_type<TensorRTOpr>()) {
        is_trt_opr = true;
        auto network = opr->cast_final_safe<TensorRTOpr>().trt_network_def();
        int nr_input = network->getNbInputs();
        for (int i = 0; i < nr_input; ++i) {
            int binding_idx =
                    engine->getBindingIndex(network->getInput(i)->getName());
            m_trt_iobuf[binding_idx] = opr->input(i)->dev_tensor().raw_ptr();
        }
        int nr_output = network->getNbOutputs();
        for (int i = 0; i < nr_output; ++i) {
            int binding_idx =
                    engine->getBindingIndex(network->getOutput(i)->getName());
            m_trt_iobuf[binding_idx] = opr->output(i)->dev_tensor().raw_ptr();
        }
    } else {
        for (size_t i = 0; i < opr->input().size(); ++i) {
            m_trt_iobuf[i] = opr->input(i)->dev_tensor().raw_ptr();
        }
        for (size_t i = 0; i < opr->output().size() - 1; ++i) {
            m_trt_iobuf[opr->input().size() + i] =
                    opr->output(i)->dev_tensor().raw_ptr();
        }
    }
    MGB_MARK_USED_VAR(is_trt_opr);
    if (should_reinit_device_memory) {
        mgb_assert(opr->output().back()->shape()[0] ==
                           intl::workspace_size(engine) &&
                   !(reinterpret_cast<uintptr_t>(workspace_ptr) % 256));
        m_context->setDeviceMemory(workspace_ptr);
        m_device_workspace_memory_ptr = workspace_ptr;
    }
    auto&& env = mgb::CompNodeEnv::from_comp_node(comp_node);
    bool exec_success = false;
#if MGB_ENABLE_JSON
    if (!pf_holder_pair.second) {
        mgb_assert(!m_has_profiler,
                   "Invalid state of TensorRTRuntimeOpr: should not have "
                   "profiler.");
#if NV_TENSOR_RT_VERSION >= 6001
        if (is_trt_opr)
            exec_success = m_context->enqueueV2(m_trt_iobuf.data(),
                                                env.cuda_env().stream, nullptr);
        else
            exec_success = m_context->enqueue(batch, m_trt_iobuf.data(),
                                              env.cuda_env().stream, nullptr);
#else
        exec_success = m_context->enqueue(batch, m_trt_iobuf.data(),
                                          env.cuda_env().stream, nullptr);
#endif
        mgb_assert(exec_success, "TensorRTOpr failed in execution.");
    } else {
        TensorRTProfiler trt_profiler;
        m_context->setProfiler(&trt_profiler);
        m_has_profiler = true;
        // The TensorRT documentation states that IExecutionContext::execute
        // "synchronously executes inference on a batch". Since it does not
        // take a cudaStream_t, we would expect it to synchronize the whole
        // device, but it appears to execute and synchronize only on its own
        // stream. So we synchronize before execution to keep the profiling
        // results accurate.
        comp_node.sync();
#if NV_TENSOR_RT_VERSION >= 6001
        if (is_trt_opr)
            exec_success = m_context->executeV2(m_trt_iobuf.data());
        else
            exec_success = m_context->execute(batch, m_trt_iobuf.data());
#else
        exec_success = m_context->execute(batch, m_trt_iobuf.data());
#endif
        mgb_assert(exec_success, "trt execution failed: opr=%s", opr->cname());
        pf_holder_pair.first[0]->id2object_map[opr] = trt_profiler.to_json();
        printf("TRT profile info of opr %s:\n", opr->name().c_str());
        trt_profiler.print_layer_times();
    }
#else
#if NV_TENSOR_RT_VERSION >= 6001
    if (is_trt_opr)
        exec_success = m_context->enqueueV2(m_trt_iobuf.data(),
                                            env.cuda_env().stream, nullptr);
    else
        exec_success = m_context->enqueue(batch, m_trt_iobuf.data(),
                                          env.cuda_env().stream, nullptr);
#else
    exec_success = m_context->enqueue(batch, m_trt_iobuf.data(),
                                      env.cuda_env().stream, nullptr);
#endif
    mgb_assert(exec_success, "trt execution failed: opr=%s", opr->cname());
#endif
}

/* ========================== TensorRTOpr ========================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(TensorRTOpr);
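
// takes ownership of the builder, network definition and (optionally)
// pre-built engine; registers one graph output per network output plus a
// trailing workspace output, and enables int8 building for NCHW4_QINT8 graphs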
TensorRTOpr::TensorRTOpr(std::shared_ptr<nvinfer1::IBuilder> builder,
                         std::shared_ptr<nvinfer1::INetworkDefinition> network,
                         TensorRTGraphFeatureBits feature_bits,
                         std::shared_ptr<GpuAllocator> gpu_allocator,
                         const VarNodeArray& inputs,
                         std::shared_ptr<nvinfer1::ICudaEngine> engine,
                         const OperatorNodeConfig& config)
        : Super(inputs.at(0)->owner_graph(), config, "tensor_rt",
                {inputs.at(0)}),
          m_gpu_allocator{std::move(gpu_allocator)},
          m_network{std::move(network)},
          m_builder{std::move(builder)},
          m_engine{std::move(engine)},
          m_feature_bits{feature_bits} {
    mgb_assert(
            inputs[0]->comp_node().device_type() == CompNode::DeviceType::CUDA,
            "TensorRTOpr can only be used on cuda comp nodes; got %s",
            inputs[0]->comp_node().to_string().c_str());
    mgb_assert(inputs.size() == static_cast<size_t>(m_network->getNbInputs()),
               "inputs size not equal: expect=%zu got=%d", inputs.size(),
               m_network->getNbInputs());
    for (auto i : inputs) {
        add_input({i});
    }
    if (m_network->getNbOutputs() == 1)
        add_output(None);
    else {
        for (int i = 0; i < m_network->getNbOutputs(); ++i)
            add_output(ssprintf("o%d", i));
    }
    cg::add_workspace_output(this);
    add_equivalence_component<mgb::ScalarHash<void*>>(m_network.get());
    mgb_assert(m_builder != nullptr);
#if NV_TENSOR_RT_VERSION >= 6001
    m_builder_config = {m_builder->createBuilderConfig(),
                        TensorRTDeleter<nvinfer1::IBuilderConfig>()};
    m_builder_config->setMaxWorkspaceSize(1 << 30);
    if (m_feature_bits == TensorRTGraphFeatureBits::NCHW4_QINT8) {
        mgb_assert(m_builder->platformHasFastInt8(),
                   "Cuda platform does not support fast native int8");
        m_builder_config->setInt8Calibrator(nullptr);
        nvinfer1::BuilderFlags flags;
        flags = 1 << static_cast<int>(nvinfer1::BuilderFlag::kINT8);
        m_builder_config->setFlags(flags);
    }
#else
    m_builder->setMaxWorkspaceSize(1 << 30);
    if (m_feature_bits == TensorRTGraphFeatureBits::NCHW4_QINT8) {
        // check has fast int8
        m_builder->setInt8Mode(true);
        m_builder->setInt8Calibrator(nullptr);
        m_builder->setStrictTypeConstraints(false);
    }
#endif
    if (!m_gpu_allocator) {
        m_gpu_allocator =
                std::make_shared<GpuAllocator>(inputs[0]->comp_node());
    }
    m_builder->setGpuAllocator(m_gpu_allocator.get());
}

SymbolVarArray TensorRTOpr::make(
        std::shared_ptr<nvinfer1::IBuilder> builder,
        std::shared_ptr<nvinfer1::INetworkDefinition> network,
        TensorRTGraphFeatureBits feature_bits,
        std::shared_ptr<GpuAllocator> gpu_allocator, const SymbolVarArray& src,
        std::shared_ptr<nvinfer1::ICudaEngine> engine,
        const OperatorNodeConfig& config) {
    VarNodeArray var_node_array = cg::to_var_node_array(src);
    auto tensor_rt_opr = std::make_unique<TensorRTOpr>(
            std::move(builder), std::move(network), feature_bits,
            std::move(gpu_allocator), var_node_array, std::move(engine),
            config);
    auto ret = cg::to_symbol_var_array(
            src[0].node()
                    ->owner_graph()
                    ->insert_opr(std::move(tensor_rt_opr))
                    ->output());
    ret.pop_back();  // remove workspace
    return ret;
}
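
// converts an nvinfer1::Dims into a MegBrain TensorShape; when `batch` is
// non-zero it is prepended as the leading dimension (implicit-batch engines)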
TensorShape TensorRTOpr::dims2shape(const nvinfer1::Dims& dims, size_t batch) {
    TensorShape ret;
    ret.ndim = dims.nbDims;
    if (batch > 0)
        ++ret.ndim;
    mgb_assert(ret.ndim <= TensorShape::MAX_NDIM,
               "TensorShape ndim > MAX_NDIM");
    if (batch > 0) {
        ret[0] = batch;
        for (size_t i = 1; i < ret.ndim; ++i) {
            ret[i] = dims.d[i - 1];
        }
    } else {
        for (size_t i = 0; i < ret.ndim; ++i) {
            ret[i] = dims.d[i];
        }
    }
    return ret;
}
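
// writes a concrete MegBrain input shape back into the corresponding TRT
// network input; for tensors whose allowed format is kCHW4, the NCHW4 shape
// (N, C/4, H, W, 4) is folded back into the 4-d NCHW dims TensorRT expects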
void TensorRTOpr::set_input_by_tensor_shape(
        nvinfer1::ITensor* const input, const TensorShape& tensor_shape) const {
    nvinfer1::Dims dims = input->getDimensions();
#if NV_TENSOR_RT_VERSION >= 6001
    auto tensor_format = input->getAllowedFormats();
    if (tensor_format &
        (1 << static_cast<int>(nvinfer1::TensorFormat::kCHW4))) {
        mgb_assert(dims.nbDims == 4 && tensor_shape.ndim == 5 &&
                           tensor_shape[4] == 4,
                   "input tensor format needs to be NCHW4 (got: %s)",
                   tensor_shape.to_string().c_str());
        for (int i = 0; i < dims.nbDims; i++) {
            dims.d[i] = tensor_shape.shape[i];
        }
        dims.d[1] *= 4;
    } else {
        mgb_assert(tensor_format &
                   (1 << static_cast<int>(nvinfer1::TensorFormat::kLINEAR)));
        mgb_assert(static_cast<int>(tensor_shape.ndim) == dims.nbDims,
                   "input dim is not equal to that in the trt network");
        for (size_t i = 0; i < tensor_shape.ndim; i++) {
            dims.d[i] = tensor_shape.shape[i];
        }
    }
#else
    mgb_assert(static_cast<int>(tensor_shape.ndim) == dims.nbDims,
               "input dim is not equal to that in the trt network");
    for (size_t i = 0; i < tensor_shape.ndim; i++) {
        dims.d[i] = tensor_shape.shape[i];
    }
#endif
    input->setDimensions(dims);
}
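
// maps TensorRT output tensor types to MegBrain dtypes; for kINT8 outputs the
// quantization scale is recovered from the tensor's dynamic range as
// range / 127 and the dtype becomes QuantizedS8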
void TensorRTOpr::init_output_dtype() {
    auto get_mgb_dtype_from_itensor = [](nvinfer1::ITensor* tensor) -> DType {
        switch (tensor->getType()) {
            case nvinfer1::DataType::kFLOAT:
                return dtype::Float32();
            case nvinfer1::DataType::kHALF:
                return dtype::Float16();
            case nvinfer1::DataType::kINT8: {
#if NV_TENSOR_RT_VERSION >= 5020
#if NV_TENSOR_RT_VERSION >= 5120
                auto range_max = tensor->getDynamicRangeMax(),
                     range_min = tensor->getDynamicRangeMin();
                auto range = std::max(range_max, range_min);
#else
                auto range = tensor->getDynamicRange();
#endif
                mgb_assert(range >= 0,
                           "trt dynamic range should be non-negative");
                static constexpr int8_t i_max =
                        std::numeric_limits<int8_t>().max();
                float scale =
                        static_cast<float>(range) / static_cast<float>(i_max);
                return dtype::QuantizedS8{scale};
#else
                return dtype::Int8();
#endif
            }
            case nvinfer1::DataType::kINT32:
                return dtype::Int32();
            default:
                mgb_throw(InternalError,
                          "trt DataType should be kFLOAT/kHALF/kINT8/kINT32.");
        }
    };
    for (int i = 0; i < m_network->getNbOutputs(); ++i) {
        output(i)->dtype(get_mgb_dtype_from_itensor(m_network->getOutput(i)));
    }
}
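
// infers output shapes by pushing the input shapes into the TRT network,
// (re)builds the ICudaEngine when no existing or cached engine matches the
// current input shapes, and reports the engine's required device workspace
// size as the shape of the trailing workspace output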
void TensorRTOpr::get_output_var_shape(const TensorShapeArray& inp_shape,
                                       TensorShapeArray& out_shape) const {
    for (size_t i = 0; i < inp_shape.size(); ++i) {
        set_input_by_tensor_shape(m_network->getInput(i), inp_shape[i]);
    }
    for (int i = 0; i < m_network->getNbOutputs(); ++i) {
#if NV_TENSOR_RT_VERSION >= 6001
        auto output = m_network->getOutput(i);
        out_shape[i] = dims2shape(output->getDimensions());
        auto tensor_format = output->getAllowedFormats();
        // fix tensor shape according to the tensor format
        if (tensor_format &
            (1 << static_cast<int>(nvinfer1::TensorFormat::kCHW4))) {
            mgb_assert(out_shape[i].ndim == 4);
            out_shape[i].ndim++;
            out_shape[i].shape[1] /= 4;
            out_shape[i].shape[4] = 4;
        }
#else
        out_shape[i] = dims2shape(m_network->getOutput(i)->getDimensions());
#endif
    }
    // the input shape is full NCHW (batch included), so the max batch size is
    // always 1
    m_builder->setMaxBatchSize(1);
    auto self = const_cast<TensorRTOpr*>(this);
    if (m_engine == nullptr && TensorRTEngineCache::enable_engine_cache()) {
        self->build_engine_from_cache();
    }
    bool engine_valid = true;
    if (m_engine == nullptr) {
        engine_valid = false;
    } else {
        int nr_input = m_network->getNbInputs();
        mgb_assert(static_cast<size_t>(nr_input) == input().size(),
                   "input size changed");
        for (int i = 0; i < nr_input; ++i) {
            int binding_idx = m_engine->getBindingIndex(
                    m_network->getInput(i)->getName());
            auto cuda_engine_shp =
                    dims2shape(m_engine->getBindingDimensions(binding_idx));
#if NV_TENSOR_RT_VERSION >= 6001
            auto tensor_format = m_engine->getBindingFormat(binding_idx);
            // fix tensor shape according to the tensor format
            if (tensor_format == nvinfer1::TensorFormat::kCHW4) {
                mgb_assert(cuda_engine_shp.ndim == 4);
                cuda_engine_shp.ndim++;
                cuda_engine_shp[1] /= 4;
                cuda_engine_shp[4] = 4;
            }
#endif
            if (!cuda_engine_shp.eq_shape(inp_shape[i])) {
                engine_valid = false;
                break;
            }
        }
    }
    if (!engine_valid) {
        // a context created from a cuda engine must be destroyed before the
        // engine itself; otherwise a segmentation fault will occur
        self->m_manager.clear_trt_context();
        RealTimer timer;
#if NV_TENSOR_RT_VERSION >= 6001
        self->m_engine = {
                m_builder->buildEngineWithConfig(*m_network, *m_builder_config),
                TensorRTDeleter<nvinfer1::ICudaEngine>()};
#else
        self->m_engine = {m_builder->buildCudaEngine(*m_network),
                          TensorRTDeleter<nvinfer1::ICudaEngine>()};
#endif
        mgb_assert(m_engine != nullptr, "build engine failed");
        mgb_log_warn("TensorRTOpr(name:%s) engine build time %.2f ms", cname(),
                     timer.get_msecs());
        if (TensorRTEngineCache::enable_engine_cache()) {
            serialize_engine_to_cache();
        }
    }
    out_shape.back() = {intl::workspace_size(m_engine.get())};
}

void TensorRTOpr::add_input_layout_constraint() {
    for (auto i : input()) {
        i->add_layout_constraint_contiguous();
    }
}

void TensorRTOpr::scn_do_execute() {
    m_manager.exec(this, m_gpu_allocator->comp_node(), m_engine.get());
}
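
// build_engine_from_cache() restores a serialized ICudaEngine from the engine
// cache keyed by this operator, leaving m_engine untouched when no entry
// exists; serialize_engine_to_cache() stores the current engine under the
// same key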
void TensorRTOpr::build_engine_from_cache() {
    TensorRTUniquePtr<nvinfer1::IRuntime> runtime{
            nvinfer1::createInferRuntime(TensorRTOpr::Logger::instance()), {}};
    runtime->setGpuAllocator(m_gpu_allocator.get());
    auto ret = TensorRTEngineCache::inst().get(
            TensorRTEngineCache::make_key_from_trt_opr(this));
    if (!ret.valid())
        return;
    auto engine = runtime->deserializeCudaEngine(
            reinterpret_cast<const void*>(ret->ptr), ret->size, nullptr);
    mgb_assert(engine, "failed to deserialize ICudaEngine");
    m_engine = {engine, TensorRTDeleter<nvinfer1::ICudaEngine>()};
}

void TensorRTOpr::serialize_engine_to_cache() const {
    TensorRTUniquePtr<nvinfer1::IHostMemory> buf{trt_cuda_engine()->serialize(),
                                                 {}};
    mgb_assert(buf, "failed to serialize ICudaEngine");
    TensorRTEngineCache::inst().put(
            TensorRTEngineCache::make_key_from_trt_opr(this),
            {buf->data(), buf->size()});
}

MGB_VERSION_SYMBOL3(TENSORRT, NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR,
                    NV_TENSORRT_PATCH);

#endif  // MGB_ENABLE_TENSOR_RT

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU and GPU build. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.