You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

context_extends.cc 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/context/context_extends.h"
  17. #include <map>
  18. #include <string>
  19. #include <memory>
  20. #include <thread>
  21. #include "pybind11/pybind11.h"
  22. #include "utils/ms_utils.h"
  23. #include "utils/convert_utils_base.h"
  24. #ifndef NO_DLIB
  25. #include "acl/acl_tdt.h"
  26. #include "runtime/dev.h"
  27. #include "toolchain/plog.h"
  28. #include "common/util/error_manager/error_manager.h"
  29. #endif
  30. #ifdef ENABLE_GE
  31. #include "transform/graph_ir/df_graph_manager.h"
  32. #endif
  33. #include "profiler/device/profiling.h"
  34. namespace py = pybind11;
  35. namespace mindspore {
  36. namespace context {
  37. #ifdef ENABLE_GE
  38. using mindspore::transform::DfGraphManager;
  39. #endif
  40. constexpr auto kUnknowErrorString = "Unknown error occurred";
  41. #ifndef NO_DLIB
  42. // Open tdt dataset
  43. bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
  44. if (ms_context_ptr == nullptr) {
  45. MS_LOG(EXCEPTION) << "nullptr";
  46. }
  47. if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
  48. return true;
  49. }
  50. if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
  51. MS_LOG(DEBUG) << "ACLTDT Dataset client is already opened.";
  52. ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
  53. return true;
  54. }
  55. auto role = common::GetEnv("MS_ROLE");
  56. if (strcmp(role.c_str(), "MS_SCHED") == 0 || strcmp(role.c_str(), "MS_PSERVER") == 0) {
  57. return true;
  58. }
  59. uint32_t rank_size = 1;
  60. uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  61. auto rank_size_env = common::GetEnv("RANK_SIZE");
  62. if (rank_size_env.empty()) {
  63. MS_LOG(INFO) << "Should config rank size.";
  64. rank_size = 1;
  65. } else {
  66. int rank_env = std::stoi(rank_size_env);
  67. if (rank_env <= 0) {
  68. MS_LOG(EXCEPTION) << "Error rank size " << rank_env << ".";
  69. }
  70. rank_size = IntToUint(rank_env);
  71. }
  72. int log_ret = DlogReportInitialize();
  73. if (log_ret != 0) {
  74. MS_LOG(WARNING) << "Init slog failed, ret = " << log_ret;
  75. }
  76. MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << ".";
  77. auto ret = rtSetDevice(static_cast<int32_t>(device_id));
  78. if (ret != RT_ERROR_NONE) {
  79. const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage();
  80. if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) {
  81. MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
  82. }
  83. MS_LOG(EXCEPTION) << "Device " << device_id << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
  84. }
  85. ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
  86. #ifdef ENABLE_TDTQUE
  87. auto thread_crt = [](const std::string &path, const acltdtChannelHandle *acl_handle) {
  88. return std::thread(TensorPrint(path, acl_handle));
  89. };
  90. ms_context_ptr->CreateTensorPrintThread(thread_crt);
  91. #endif
  92. return true;
  93. }
  94. bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
  95. if (ms_context_ptr == nullptr) {
  96. MS_LOG(EXCEPTION) << "ms_context_prt is nullptr";
  97. }
  98. if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
  99. return true;
  100. }
  101. ms_context_ptr->decrease_param<uint32_t>(MS_CTX_TSD_REF);
  102. if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
  103. ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
  104. #ifdef ENABLE_TDTQUE
  105. py::gil_scoped_release gil_release;
  106. ms_context_ptr->DestroyTensorPrintThread();
  107. #endif
  108. uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  109. auto ret = rtDeviceReset(static_cast<int32_t>(device_id));
  110. if (ret != RT_ERROR_NONE) {
  111. const std::string &error_message = ErrorManager::GetInstance().GetErrorMessage();
  112. if (!error_message.empty() && error_message.find(kUnknowErrorString) == std::string::npos) {
  113. MS_LOG(ERROR) << "Ascend error occurred, error message:\n" << error_message;
  114. }
  115. MS_LOG(EXCEPTION) << "Device " << device_id << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
  116. return false;
  117. }
  118. ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
  119. MS_LOG(INFO) << "Call rtDeviceReset, destroy and close tsd successful, ret[" << static_cast<int>(ret) << "]";
  120. (void)DlogReportFinalize();
  121. } else {
  122. MS_LOG(DEBUG) << "Acltdt Dataset client is used, no need to close, tsd reference = "
  123. << ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
  124. }
  125. return true;
  126. }
  127. #else
  128. bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) { return true; }
  129. bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool) { return true; }
  130. #endif
  131. void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options) {
  132. auto env_disable_reuse_memory = common::GetEnv("DISABLE_REUSE_MEMORY");
  133. if (!env_disable_reuse_memory.empty()) {
  134. (*ge_options)["ge.exec.disableReuseMemory"] = env_disable_reuse_memory;
  135. } else {
  136. (*ge_options)["ge.exec.disableReuseMemory"] = "0";
  137. MS_LOG(WARNING) << "DISABLE_REUSE_MEMORY is not set in ENV. Now set to default value 0";
  138. }
  139. }
  140. void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std::string, std::string> *ge_options) {
  141. if (ms_context_ptr == nullptr) {
  142. MS_LOG(EXCEPTION) << "nullptr";
  143. }
  144. #ifdef ENABLE_GE
  145. (*ge_options)["device_id"] = "0";
  146. (*ge_options)["ge.exec.enableDump"] = std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP));
  147. (*ge_options)["ge.exec.dumpPath"] = ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH);
  148. (*ge_options)["ge.exec.dumpMode"] = "output";
  149. MS_LOG(INFO) << "The enable dump state is " << std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP))
  150. << " and save dump path is " << ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH) << ".";
  151. auto profiler_manager = profiler::ProfilerManager::GetInstance();
  152. if (profiler_manager == nullptr) {
  153. MS_LOG(EXCEPTION) << "Profiler manager is nullptr";
  154. }
  155. (*ge_options)["ge.exec.profilingMode"] = std::to_string(profiler_manager->GetProfilingEnableFlag());
  156. if (profiler_manager->GetProfilingEnableFlag()) {
  157. (*ge_options)["ge.exec.profilingOptions"] = profiler_manager->GetProfilingOptions();
  158. }
  159. (*ge_options)["rank_table_file"] = "";
  160. auto env_ddk_version = common::GetEnv("DDK_VERSION");
  161. if (!env_ddk_version.empty()) {
  162. (*ge_options)["ge.DDK_version"] = env_ddk_version;
  163. } else {
  164. (*ge_options)["ge.DDK_version"] = "1.60.T17.B830";
  165. }
  166. (*ge_options)["graphType"] = "1";
  167. if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
  168. (*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
  169. }
  170. if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
  171. (*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
  172. }
  173. #if ENABLE_TRAIN == 1
  174. (*ge_options)["ge.graphRunMode"] = "1";
  175. #endif
  176. SetDisableReuseMemoryFlag(ge_options);
  177. SetHcclOptions(ms_context_ptr, ge_options);
  178. auto env_job_id = common::GetEnv("JOB_ID");
  179. if (!env_job_id.empty()) {
  180. (*ge_options)["ge.exec.jobId"] = env_job_id;
  181. } else {
  182. (*ge_options)["ge.exec.jobId"] = "0";
  183. MS_LOG(WARNING) << "JOB_ID is not set in ENV. Now set to default value 0";
  184. }
  185. auto env_fe_flag = common::GetEnv("FE_FLAG");
  186. if (!env_fe_flag.empty()) {
  187. (*ge_options)["ge.feFlag"] = env_fe_flag;
  188. MS_LOG(INFO) << "Use FE, make sure fe lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
  189. }
  190. auto env_aicpu_flag = common::GetEnv("AICPU_FLAG");
  191. if (!env_aicpu_flag.empty()) {
  192. (*ge_options)["ge.aicpuFlag"] = env_aicpu_flag;
  193. MS_LOG(INFO) << "Use AICPU, make sure aicpu lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
  194. }
  195. auto proto_lib_path = common::GetEnv("OPTION_PROTO_LIB_PATH");
  196. if (!proto_lib_path.empty()) {
  197. char real_path[PATH_MAX] = {0};
  198. if (realpath(proto_lib_path.c_str(), real_path)) {
  199. proto_lib_path = real_path;
  200. (*ge_options)["ge.opsProtoLibPath"] = proto_lib_path;
  201. }
  202. } else {
  203. MS_LOG(WARNING) << "Set proto lib path failed!";
  204. }
  205. (*ge_options)["ge.exec.precision_mode"] = "force_fp16";
  206. // Disable the global variable acc, only enable it while adding training graph in pipeline
  207. (*ge_options)["ge.exec.variable_acc"] = "0";
  208. #endif
  209. }
  210. void SetHcclOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std::string, std::string> *ge_options) {
  211. if (ms_context_ptr == nullptr) {
  212. MS_LOG(EXCEPTION) << "nullptr";
  213. }
  214. auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
  215. auto env_rank_id = common::GetEnv("RANK_ID");
  216. auto env_device_id = std::to_string(ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID));
  217. if (!(env_table_file.empty() || env_rank_id.empty())) {
  218. MS_LOG(INFO) << "Initialize Ge for distribute parameter";
  219. MS_LOG(INFO) << "Use hccl, make sure hccl lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
  220. auto env_hccl_flag = common::GetEnv("HCCL_FLAG");
  221. if (!env_hccl_flag.empty()) {
  222. (*ge_options)["ge.exec.hcclFlag"] = env_hccl_flag;
  223. }
  224. (*ge_options)["ge.exec.isUseHcom"] = "1";
  225. (*ge_options)["ge.exec.deviceId"] = env_device_id;
  226. (*ge_options)["ge.exec.rankId"] = env_rank_id;
  227. (*ge_options)["ge.exec.podName"] = env_rank_id;
  228. (*ge_options)["ge.exec.rankTableFile"] = env_table_file;
  229. (*ge_options)["ge.graphRunMode"] = "1";
  230. } else {
  231. // device id is still needed for non-distribute case
  232. (*ge_options)["ge.exec.deviceId"] = env_device_id;
  233. MS_LOG(INFO) << "No hccl mode. "
  234. "If use hccl, make sure [RANK_TABLE_FILE,RANK_ID,DEVICE_ID,DEPLOY_MODE] all be set in ENV.";
  235. }
  236. auto env_deploy_mode = common::GetEnv("DEPLOY_MODE");
  237. if (!env_deploy_mode.empty()) {
  238. (*ge_options)["ge.exec.deployMode"] = env_deploy_mode;
  239. } else {
  240. (*ge_options)["ge.exec.deployMode"] = "0";
  241. MS_LOG(WARNING) << "DEPLOY_MODE is not set in ENV. Now set to default value 0";
  242. }
  243. }
  244. bool InitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
  245. if (ms_context_ptr == nullptr) {
  246. MS_LOG(EXCEPTION) << "nullptr";
  247. }
  248. #ifdef ENABLE_GE
  249. if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
  250. return true;
  251. }
  252. if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF)) {
  253. ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF);
  254. return true;
  255. }
  256. std::map<std::string, std::string> ge_options;
  257. GetGeOptions(ms_context_ptr, &ge_options);
  258. {
  259. // Release GIL before calling into (potentially long-running) C++ code
  260. py::gil_scoped_release release;
  261. if (ge::GEInitialize(ge_options) != ge::GRAPH_SUCCESS) {
  262. MS_LOG(EXCEPTION) << "Initialize GE failed!";
  263. }
  264. }
  265. ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF);
  266. MS_LOG(INFO) << "Init ge successful, ge reference = " << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
  267. #endif
  268. return true;
  269. }
  270. bool PynativeInitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
  271. if (ms_context_ptr == nullptr) {
  272. MS_LOG(EXCEPTION) << "nullptr";
  273. }
  274. if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT) ||
  275. ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
  276. return true;
  277. }
  278. (void)OpenTsd(ms_context_ptr);
  279. (void)InitGe(ms_context_ptr);
  280. ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, true);
  281. return true;
  282. }
  283. bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
  284. if (ms_context_ptr == nullptr) {
  285. MS_LOG(EXCEPTION) << "nullptr";
  286. }
  287. #ifdef ENABLE_GE
  288. if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
  289. return true;
  290. }
  291. ms_context_ptr->decrease_param<uint32_t>(MS_CTX_GE_REF);
  292. if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
  293. ms_context_ptr->set_param<uint32_t>(MS_CTX_GE_REF, 0);
  294. try {
  295. DfGraphManager::GetInstance().DeleteGraphRunner();
  296. DfGraphManager::GetInstance().DeleteGeSession();
  297. } catch (const std::exception &e) {
  298. MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Error: " << e.what();
  299. } catch (...) {
  300. std::string exName(abi::__cxa_current_exception_type()->name());
  301. MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Exception name: " << exName;
  302. }
  303. if (ge::GEFinalize() != ge::GRAPH_SUCCESS) {
  304. MS_LOG(WARNING) << "Finalize GE failed!";
  305. }
  306. ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
  307. } else {
  308. MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = "
  309. << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
  310. }
  311. #endif
  312. return true;
  313. }
  314. bool IsTsdOpened(const std::shared_ptr<MsContext> &ms_context_ptr) {
  315. if (ms_context_ptr == nullptr) {
  316. MS_LOG(EXCEPTION) << "nullptr";
  317. }
  318. return ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) > 0;
  319. }
  320. bool IsGeInited(const std::shared_ptr<MsContext> &ms_context_ptr) {
  321. if (ms_context_ptr == nullptr) {
  322. MS_LOG(EXCEPTION) << "nullptr";
  323. }
  324. return ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) > 0;
  325. }
  326. // Register for device type.
  327. struct DeviceTypeSetRegister {
  328. DeviceTypeSetRegister() {
  329. MsContext::device_type_seter([](std::shared_ptr<MsContext> &device_type_seter) {
  330. #ifdef ENABLE_GE
  331. device_type_seter.reset(new (std::nothrow) MsContext("ge", kAscendDevice));
  332. #elif defined(ENABLE_D)
  333. device_type_seter.reset(new (std::nothrow) MsContext("ms", kAscendDevice));
  334. #elif defined(ENABLE_GPU)
  335. device_type_seter.reset(new (std::nothrow) MsContext("ms", kGPUDevice));
  336. #else
  337. device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice));
  338. #endif
  339. });
  340. }
  341. DeviceTypeSetRegister(const DeviceTypeSetRegister &) = delete;
  342. DeviceTypeSetRegister &operator=(const DeviceTypeSetRegister &) = delete;
  343. ~DeviceTypeSetRegister() = default;
  344. } device_type_set_regsiter;
  345. } // namespace context
  346. } // namespace mindspore