diff --git a/mindspore/ccsrc/utils/context/context_extends.cc b/mindspore/ccsrc/utils/context/context_extends.cc index 1e4ebaf7a2..9c987b6f48 100644 --- a/mindspore/ccsrc/utils/context/context_extends.cc +++ b/mindspore/ccsrc/utils/context/context_extends.cc @@ -74,9 +74,9 @@ bool OpenTsd(const std::shared_ptr &ms_context_ptr) { } MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << "."; - TDT_StatusT status = TsdOpen(device_id, rank_size); - if (status != TDT_OK) { - MS_LOG(EXCEPTION) << "Device " << device_id << " open tsd failed, status = " << status << "."; + auto ret = rtSetDevice(device_id); + if (ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Device " << device_id << " call rtSetDevice failed, ret[" << static_cast(ret) << "]"; return false; } ms_context_ptr->increase_param(MS_CTX_TSD_REF); @@ -125,13 +125,13 @@ bool CloseTsd(const std::shared_ptr &ms_context_ptr, bool force) { } #endif auto device_id = ms_context_ptr->get_param(MS_CTX_DEVICE_ID); - TDT_StatusT status = TsdClose(device_id); - if (status != TDT_OK) { - MS_LOG(EXCEPTION) << "Close tsd failed, status = " << status << "."; + auto ret = rtDeviceReset(device_id); + if (ret != RT_ERROR_NONE) { + MS_LOG(EXCEPTION) << "Device " << device_id << " call rtDeviceReset failed, ret[" << static_cast(ret) << "]"; return false; } ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, false); - MS_LOG(INFO) << "Destroy and close tsd successful, status = " << status << "."; + MS_LOG(INFO) << "Call rtDeviceReset, destroy and close tsd successful, ret[" << static_cast(ret) << "]"; } else { MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = " << ms_context_ptr->get_param(MS_CTX_TSD_REF) << "."; diff --git a/mindspore/ccsrc/utils/context/context_extends.h b/mindspore/ccsrc/utils/context/context_extends.h index 64408e277b..36e6036e17 100644 --- a/mindspore/ccsrc/utils/context/context_extends.h +++ b/mindspore/ccsrc/utils/context/context_extends.h @@ -27,6 +27,7 @@ #include "tdt/tsd_client.h" #include "tdt/tdt_host_interface.h" #include "tdt/data_common.h" +#include "runtime/dev.h" #endif #ifdef ENABLE_GE #include "transform/graph_ir/df_graph_manager.h"