|
|
|
@@ -334,12 +334,33 @@ Backend::Backend(const std::string &name) : name_(name) { |
|
|
|
simu_flag_ = false; |
|
|
|
} |
|
|
|
|
|
|
|
bool IsCloudTransSessDeviceId() { |
|
|
|
auto deploy_mode = common::GetEnv("DEPLOY_MODE"); |
|
|
|
if (deploy_mode.empty() || deploy_mode != "1") { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
auto rank_size = common::GetEnv("RANK_SIZE"); |
|
|
|
if (rank_size.empty() || rank_size != "1") { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) { |
|
|
|
convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2); |
|
|
|
target_sess_ = session::SessionFactory::Get().Create(target); |
|
|
|
if (target_sess_ == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available."; |
|
|
|
} |
|
|
|
|
|
|
|
MS_LOG(INFO) << "Before trans, device id: " << device_id; |
|
|
|
if (IsCloudTransSessDeviceId()) { |
|
|
|
device_id = 0; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "After trans, device id: " << device_id; |
|
|
|
|
|
|
|
target_sess_->Init(device_id); |
|
|
|
target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); |
|
|
|
target_device_ = target; |
|
|
|
|