|
|
|
@@ -24,11 +24,13 @@ |
|
|
|
#include <sstream> |
|
|
|
#include <algorithm> |
|
|
|
#include "runtime/device/gpu/trt_loader.h" |
|
|
|
#include "runtime/device/gpu/cuda_driver.h" |
|
|
|
#include "backend/optimizer/trt_pass/trt_op_factory.h" |
|
|
|
#include "backend/kernel_compiler/gpu/trt/trt_utils.h" |
|
|
|
#include "utils/convert_utils.h" |
|
|
|
#include "utils/utils.h" |
|
|
|
#include "utils/singleton.h" |
|
|
|
#include "utils/ms_context.h" |
|
|
|
|
|
|
|
namespace mindspore::opt { |
|
|
|
namespace { |
|
|
|
@@ -121,6 +123,15 @@ void GetRealInputs(const AnfNodePtr &node, std::vector<session::KernelWithIndex> |
|
|
|
} // namespace |
|
|
|
|
|
|
|
bool TrtConverterContext::Init() { |
|
|
|
// Set device id before invoke trt api as cudaSetDevice is thread level config. |
|
|
|
const auto &context = MsContext::GetInstance(); |
|
|
|
const auto &device_id = context->get_param<uint32_t>(MS_CTX_DEVICE_ID); |
|
|
|
bool ret = device::gpu::CudaDriver::SetDevice(UintToInt(device_id)); |
|
|
|
if (!ret) { |
|
|
|
MS_LOG(ERROR) << "Failed to set device id:" << device_id; |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance(); |
|
|
|
builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance()); |
|
|
|
MS_EXCEPTION_IF_NULL(builder_); |
|
|
|
|