From bdc67ee2ca2a9198bc42e06b5b170d33bbabf3d1 Mon Sep 17 00:00:00 2001 From: changzherui Date: Thu, 23 Jul 2020 11:31:41 +0800 Subject: [PATCH] modify device id --- mindspore/ccsrc/utils/context/ms_context.cc | 3 ++- mindspore/ccsrc/vm/backend.cc | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc index 6c3986e422..3c2429f89e 100644 --- a/mindspore/ccsrc/utils/context/ms_context.cc +++ b/mindspore/ccsrc/utils/context/ms_context.cc @@ -82,6 +82,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { if (IsCloudTransDeviceId()) { device_id_ = 0; } + MS_LOG(INFO) << "context logic id: " << device_id_ << "context physics id: " << physics_id_; backend_policy_ = policy_map_[policy]; device_target_ = target; @@ -172,7 +173,7 @@ bool MsContext::set_device_id(uint32_t device_id) { if (IsCloudTransDeviceId()) { device_id_ = 0; } - MS_LOG(INFO) << "ms set context logic id:" << device_id; + MS_LOG(INFO) << "ms set context logic id:" << device_id_; return true; } diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index 4341427a67..135ced65f7 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -334,12 +334,33 @@ Backend::Backend(const std::string &name) : name_(name) { simu_flag_ = false; } +bool IsCloudTransSessDeviceId() { + auto deploy_mode = common::GetEnv("DEPLOY_MODE"); + if (deploy_mode.empty() || deploy_mode != "1") { + return false; + } + + auto rank_size = common::GetEnv("RANK_SIZE"); + if (rank_size.empty() || rank_size != "1") { + return false; + } + + return true; +} + MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) { convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2); target_sess_ = session::SessionFactory::Get().Create(target); if (target_sess_ == nullptr) { MS_LOG(EXCEPTION) << "Session create failed!, please make sure target device:" << target << " is available."; } + + MS_LOG(INFO) << "Before trans, device id: " << device_id; + if (IsCloudTransSessDeviceId()) { + device_id = 0; + } + MS_LOG(INFO) << "After trans, device id: " << device_id; + target_sess_->Init(device_id); target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); target_device_ = target;