You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

context_extends.cc 14 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/context/context_extends.h"
  17. #include <map>
  18. #include <string>
  19. #include <memory>
  20. #include <thread>
  21. #include <atomic>
  22. #include "pybind11/pybind11.h"
  23. #include "utils/ms_utils.h"
  24. #include "utils/convert_utils_base.h"
  25. namespace py = pybind11;
  26. namespace mindspore {
  27. namespace context {
  28. #ifdef ENABLE_GE
  29. using mindspore::transform::DfGraphManager;
  30. #endif
  31. #ifndef NO_DLIB
  32. // Open tdt dataset
  33. bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
  34. if (ms_context_ptr == nullptr) {
  35. MS_LOG(EXCEPTION) << "nullptr";
  36. }
  37. if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
  38. return true;
  39. }
  40. if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
  41. MS_LOG(DEBUG) << "TDT Dataset client is already opened.";
  42. ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
  43. return true;
  44. }
  45. auto role = common::GetEnv("MS_ROLE");
  46. if (strcmp(role.c_str(), "MS_SCHED") == 0 || strcmp(role.c_str(), "MS_PSERVER") == 0) {
  47. return true;
  48. }
  49. unsigned int device_id;
  50. unsigned int rank_size = 1;
  51. device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  52. auto rank_size_env = common::GetEnv("RANK_SIZE");
  53. if (rank_size_env.empty()) {
  54. MS_LOG(INFO) << "Should config rank size.";
  55. rank_size = 1;
  56. } else {
  57. int rank_env = std::stoi(rank_size_env);
  58. if (rank_env <= 0) {
  59. MS_LOG(EXCEPTION) << "Error rank size " << rank_env << ".";
  60. }
  61. rank_size = IntToUint(rank_env);
  62. }
  63. MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << ".";
  64. TDT_StatusT status = TsdOpen(device_id, rank_size);
  65. if (status != TDT_OK) {
  66. MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << ".";
  67. return false;
  68. }
  69. ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
  70. #ifdef ENABLE_TDTQUE
  71. int32_t initStatus = tdt::TdtHostInit(device_id);
  72. if (initStatus != TDT_OK_CODE) {
  73. MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << ".";
  74. return false;
  75. }
  76. ms_context_ptr->tdt_print_ = std::thread(TensorPrint());
  77. #endif
  78. MS_LOG(INFO) << "Open and init tsd successful, tsd reference = "
  79. << ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
  80. return true;
  81. }
  82. bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
  83. if (ms_context_ptr == nullptr) {
  84. MS_LOG(EXCEPTION) << "nullptr";
  85. }
  86. if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
  87. return true;
  88. }
  89. ms_context_ptr->decrease_param<uint32_t>(MS_CTX_TSD_REF);
  90. if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
  91. ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
  92. #ifdef ENABLE_TDTQUE
  93. int32_t stopStatus = tdt::TdtHostStop(KNpuLog);
  94. if (stopStatus != TDT_OK_CODE) {
  95. MS_LOG(EXCEPTION) << "Stop tsd failed, status = " << stopStatus << ".";
  96. return false;
  97. }
  98. py::gil_scoped_release gil_release;
  99. int32_t destroyStatus = tdt::TdtHostDestroy();
  100. if (destroyStatus != TDT_OK_CODE) {
  101. MS_LOG(EXCEPTION) << "Destroy tsd failed, status = " << destroyStatus << ".";
  102. return false;
  103. }
  104. try {
  105. if (ms_context_ptr->tdt_print_.joinable()) {
  106. MS_LOG(INFO) << "join tdt host receive process";
  107. ms_context_ptr->tdt_print_.join();
  108. }
  109. } catch (const std::exception &e) {
  110. MS_LOG(ERROR) << "tdt thread join failed: " << e.what();
  111. }
  112. #endif
  113. auto device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  114. TDT_StatusT status = TsdClose(device_id);
  115. if (status != TDT_OK) {
  116. MS_LOG(EXCEPTION) << "Close tsd failed, status = " << status << ".";
  117. return false;
  118. }
  119. ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
  120. MS_LOG(INFO) << "Destroy and close tsd successful, status = " << status << ".";
  121. } else {
  122. MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = "
  123. << ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
  124. }
  125. return true;
  126. }
  127. #else
  128. bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) { return true; }
  129. bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool) { return true; }
  130. #endif
  131. void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options) {
  132. auto env_disable_reuse_memory = common::GetEnv("DISABLE_REUSE_MEMORY");
  133. if (!env_disable_reuse_memory.empty()) {
  134. (*ge_options)["ge.exec.disableReuseMemory"] = env_disable_reuse_memory;
  135. } else {
  136. (*ge_options)["ge.exec.disableReuseMemory"] = "0";
  137. MS_LOG(WARNING) << "DISABLE_REUSE_MEMORY is not set in ENV. Now set to default value 0";
  138. }
  139. }
  140. void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std::string, std::string> *ge_options) {
  141. if (ms_context_ptr == nullptr) {
  142. MS_LOG(EXCEPTION) << "nullptr";
  143. }
  144. #ifdef ENABLE_GE
  145. (*ge_options)["device_id"] = "0";
  146. (*ge_options)["ge.exec.enableDump"] = std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP));
  147. (*ge_options)["ge.exec.dumpPath"] = ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH);
  148. (*ge_options)["ge.exec.dumpMode"] = "output";
  149. MS_LOG(INFO) << "The enable dump state is " << std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP))
  150. << " and save dump path is " << ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH) << ".";
  151. (*ge_options)["ge.exec.profilingMode"] = std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_PROFILING));
  152. if (ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_PROFILING)) {
  153. (*ge_options)["ge.exec.profilingOptions"] = ms_context_ptr->get_param<std::string>(MS_CTX_PROFILING_OPTIONS);
  154. }
  155. (*ge_options)["rank_table_file"] = "";
  156. auto env_ddk_version = common::GetEnv("DDK_VERSION");
  157. if (!env_ddk_version.empty()) {
  158. (*ge_options)["ge.DDK_version"] = env_ddk_version;
  159. } else {
  160. (*ge_options)["ge.DDK_version"] = "1.60.T17.B830";
  161. }
  162. (*ge_options)["graphType"] = "1";
  163. if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
  164. (*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
  165. }
  166. if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
  167. (*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
  168. }
  169. #if ENABLE_TRAIN == 1
  170. (*ge_options)["ge.graphRunMode"] = "1";
  171. #endif
  172. SetDisableReuseMemoryFlag(ge_options);
  173. SetHcclOptions(ms_context_ptr, ge_options);
  174. auto env_job_id = common::GetEnv("JOB_ID");
  175. if (!env_job_id.empty()) {
  176. (*ge_options)["ge.exec.jobId"] = env_job_id;
  177. } else {
  178. (*ge_options)["ge.exec.jobId"] = "0";
  179. MS_LOG(WARNING) << "JOB_ID is not set in ENV. Now set to default value 0";
  180. }
  181. auto env_fe_flag = common::GetEnv("FE_FLAG");
  182. if (!env_fe_flag.empty()) {
  183. (*ge_options)["ge.feFlag"] = env_fe_flag;
  184. MS_LOG(INFO) << "Use FE, make sure fe lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
  185. }
  186. auto env_aicpu_flag = common::GetEnv("AICPU_FLAG");
  187. if (!env_aicpu_flag.empty()) {
  188. (*ge_options)["ge.aicpuFlag"] = env_aicpu_flag;
  189. MS_LOG(INFO) << "Use AICPU, make sure aicpu lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
  190. }
  191. auto proto_lib_path = common::GetEnv("OPTION_PROTO_LIB_PATH");
  192. if (!proto_lib_path.empty()) {
  193. char real_path[PATH_MAX] = {0};
  194. if (realpath(proto_lib_path.c_str(), real_path)) {
  195. proto_lib_path = real_path;
  196. (*ge_options)["ge.opsProtoLibPath"] = proto_lib_path;
  197. }
  198. } else {
  199. MS_LOG(WARNING) << "Set proto lib path failed!";
  200. }
  201. // Enable auto mixed precision according to the context options
  202. if (ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_AUTO_MIXED_PRECISION)) {
  203. (*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision";
  204. } else {
  205. (*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16";
  206. }
  207. // Disable the global variable acc, only enable it whlie adding training graph in pipeline
  208. (*ge_options)["ge.exec.variable_acc"] = "0";
  209. #endif
  210. }
  211. void SetHcclOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std::string, std::string> *ge_options) {
  212. if (ms_context_ptr == nullptr) {
  213. MS_LOG(EXCEPTION) << "nullptr";
  214. }
  215. auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
  216. auto env_rank_id = common::GetEnv("RANK_ID");
  217. auto env_device_id = std::to_string(ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID));
  218. if (!(env_table_file.empty() || env_rank_id.empty())) {
  219. MS_LOG(INFO) << "Initialize Ge for distribute parameter";
  220. MS_LOG(INFO) << "Use hccl, make sure hccl lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
  221. auto env_hccl_flag = common::GetEnv("HCCL_FLAG");
  222. if (!env_hccl_flag.empty()) {
  223. (*ge_options)["ge.exec.hcclFlag"] = env_hccl_flag;
  224. }
  225. (*ge_options)["ge.exec.isUseHcom"] = "1";
  226. (*ge_options)["ge.exec.deviceId"] = env_device_id;
  227. (*ge_options)["ge.exec.rankId"] = env_rank_id;
  228. (*ge_options)["ge.exec.podName"] = env_rank_id;
  229. (*ge_options)["ge.exec.rankTableFile"] = env_table_file;
  230. (*ge_options)["ge.graphRunMode"] = "1";
  231. } else {
  232. // device id is still needed for non-distribute case
  233. (*ge_options)["ge.exec.deviceId"] = env_device_id;
  234. MS_LOG(INFO) << "No hccl mode. "
  235. "If use hccl, make sure [RANK_TABLE_FILE,RANK_ID,DEVICE_ID,DEPLOY_MODE] all be set in ENV.";
  236. }
  237. auto env_deploy_mode = common::GetEnv("DEPLOY_MODE");
  238. if (!env_deploy_mode.empty()) {
  239. (*ge_options)["ge.exec.deployMode"] = env_deploy_mode;
  240. } else {
  241. (*ge_options)["ge.exec.deployMode"] = "0";
  242. MS_LOG(WARNING) << "DEPLOY_MODE is not set in ENV. Now set to default value 0";
  243. }
  244. }
  245. bool InitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
  246. if (ms_context_ptr == nullptr) {
  247. MS_LOG(EXCEPTION) << "nullptr";
  248. }
  249. #ifdef ENABLE_GE
  250. if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
  251. return true;
  252. }
  253. if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF)) {
  254. ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF);
  255. return true;
  256. }
  257. std::map<std::string, std::string> ge_options;
  258. GetGeOptions(ms_context_ptr, &ge_options);
  259. {
  260. // Release GIL before calling into (potentially long-running) C++ code
  261. py::gil_scoped_release release;
  262. if (ge::GEInitialize(ge_options) != ge::GRAPH_SUCCESS) {
  263. MS_LOG(EXCEPTION) << "Initialize GE failed!";
  264. }
  265. }
  266. ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF);
  267. MS_LOG(INFO) << "Init ge successful, ge reference = " << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
  268. #endif
  269. return true;
  270. }
  271. bool PynativeInitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
  272. if (ms_context_ptr == nullptr) {
  273. MS_LOG(EXCEPTION) << "nullptr";
  274. }
  275. if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT) ||
  276. ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
  277. return true;
  278. }
  279. (void)OpenTsd(ms_context_ptr);
  280. (void)InitGe(ms_context_ptr);
  281. ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, true);
  282. return true;
  283. }
  284. bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
  285. if (ms_context_ptr == nullptr) {
  286. MS_LOG(EXCEPTION) << "nullptr";
  287. }
  288. #ifdef ENABLE_GE
  289. if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
  290. return true;
  291. }
  292. ms_context_ptr->decrease_param<uint32_t>(MS_CTX_GE_REF);
  293. if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
  294. ms_context_ptr->set_param<uint32_t>(MS_CTX_GE_REF, 0);
  295. try {
  296. DfGraphManager::GetInstance().DeleteGraphRunner();
  297. DfGraphManager::GetInstance().DeleteGeSession();
  298. } catch (const std::exception &e) {
  299. MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Error: " << e.what();
  300. } catch (...) {
  301. std::string exName(abi::__cxa_current_exception_type()->name());
  302. MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Exception name: " << exName;
  303. }
  304. if (ge::GEFinalize() != ge::GRAPH_SUCCESS) {
  305. MS_LOG(WARNING) << "Finalize GE failed!";
  306. }
  307. ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
  308. } else {
  309. MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = "
  310. << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
  311. }
  312. #endif
  313. return true;
  314. }
  315. bool IsTsdOpened(const std::shared_ptr<MsContext> &ms_context_ptr) {
  316. if (ms_context_ptr == nullptr) {
  317. MS_LOG(EXCEPTION) << "nullptr";
  318. }
  319. return ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) > 0;
  320. }
  321. bool IsGeInited(const std::shared_ptr<MsContext> &ms_context_ptr) {
  322. if (ms_context_ptr == nullptr) {
  323. MS_LOG(EXCEPTION) << "nullptr";
  324. }
  325. return ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) > 0;
  326. }
  327. // Register for device type.
  328. struct DeviceTypeSetRegister {
  329. DeviceTypeSetRegister() {
  330. MsContext::device_type_seter([](std::shared_ptr<MsContext> &device_type_seter) {
  331. #ifdef ENABLE_GE
  332. device_type_seter.reset(new (std::nothrow) MsContext("ge", kAscendDevice));
  333. #elif defined(ENABLE_D)
  334. device_type_seter.reset(new (std::nothrow) MsContext("ms", kAscendDevice));
  335. #elif defined(ENABLE_GPU)
  336. device_type_seter.reset(new (std::nothrow) MsContext("ms", kGPUDevice));
  337. #else
  338. device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice));
  339. #endif
  340. });
  341. }
  342. ~DeviceTypeSetRegister() = default;
  343. } device_type_set_regsiter;
  344. } // namespace context
  345. } // namespace mindspore