| @@ -24,11 +24,13 @@ | |||||
| #include <sstream> | #include <sstream> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "runtime/device/gpu/trt_loader.h" | #include "runtime/device/gpu/trt_loader.h" | ||||
| #include "runtime/device/gpu/cuda_driver.h" | |||||
| #include "backend/optimizer/trt_pass/trt_op_factory.h" | #include "backend/optimizer/trt_pass/trt_op_factory.h" | ||||
| #include "backend/kernel_compiler/gpu/trt/trt_utils.h" | #include "backend/kernel_compiler/gpu/trt/trt_utils.h" | ||||
| #include "utils/convert_utils.h" | #include "utils/convert_utils.h" | ||||
| #include "utils/utils.h" | #include "utils/utils.h" | ||||
| #include "utils/singleton.h" | #include "utils/singleton.h" | ||||
| #include "utils/ms_context.h" | |||||
| namespace mindspore::opt { | namespace mindspore::opt { | ||||
| namespace { | namespace { | ||||
| @@ -121,6 +123,15 @@ void GetRealInputs(const AnfNodePtr &node, std::vector<session::KernelWithIndex> | |||||
| } // namespace | } // namespace | ||||
| bool TrtConverterContext::Init() { | bool TrtConverterContext::Init() { | ||||
| // Set device id before invoke trt api as cudaSetDevice is thread level config. | |||||
| const auto &context = MsContext::GetInstance(); | |||||
| const auto &device_id = context->get_param<uint32_t>(MS_CTX_DEVICE_ID); | |||||
| bool ret = device::gpu::CudaDriver::SetDevice(UintToInt(device_id)); | |||||
| if (!ret) { | |||||
| MS_LOG(ERROR) << "Failed to set device id:" << device_id; | |||||
| return false; | |||||
| } | |||||
| auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance(); | auto trt_loader = Singleton<device::gpu::TrtLoader>::Instance(); | ||||
| builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance()); | builder_ = trt_loader.CreateInferBuilder(&Singleton<TrtLogger>::Instance()); | ||||
| MS_EXCEPTION_IF_NULL(builder_); | MS_EXCEPTION_IF_NULL(builder_); | ||||