You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

context_extends.cc 14 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "utils/context/context_extends.h"
  #include <atomic>
  #include <exception>
  #include <map>
  #include <memory>
  #include <string>
  #include <thread>
  #include "pybind11/pybind11.h"
  #include "utils/ms_utils.h"
  #include "utils/convert_utils_base.h"
  25. namespace py = pybind11;
  26. namespace mindspore {
  27. namespace context {
  28. #ifdef ENABLE_GE
  29. using mindspore::transform::DfGraphManager;
  30. #endif
  31. #ifndef NO_DLIB
  32. // Open tdt dataset
  33. bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
  34. if (ms_context_ptr == nullptr) {
  35. MS_LOG(EXCEPTION) << "nullptr";
  36. }
  37. if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
  38. return true;
  39. }
  40. if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
  41. MS_LOG(DEBUG) << "ACLTDT Dataset client is already opened.";
  42. ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
  43. return true;
  44. }
  45. auto role = common::GetEnv("MS_ROLE");
  46. if (strcmp(role.c_str(), "MS_SCHED") == 0 || strcmp(role.c_str(), "MS_PSERVER") == 0) {
  47. return true;
  48. }
  49. uint32_t rank_size = 1;
  50. uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  51. auto rank_size_env = common::GetEnv("RANK_SIZE");
  52. if (rank_size_env.empty()) {
  53. MS_LOG(INFO) << "Should config rank size.";
  54. rank_size = 1;
  55. } else {
  56. int rank_env = std::stoi(rank_size_env);
  57. if (rank_env <= 0) {
  58. MS_LOG(EXCEPTION) << "Error rank size " << rank_env << ".";
  59. }
  60. rank_size = IntToUint(rank_env);
  61. }
  62. MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << ".";
  63. auto ret = rtSetDevice(device_id);
  64. if (ret != RT_ERROR_NONE) {
  65. MS_LOG(EXCEPTION) << "Device " << device_id << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
  66. return false;
  67. }
  68. ms_context_ptr->increase_param<uint32_t>(MS_CTX_TSD_REF);
  69. #ifdef ENABLE_TDTQUE
  70. acltdtChannelHandle *acl_handle = ms_context_ptr->CreateAclTdtChannelHandle();
  71. if (acl_handle == nullptr) {
  72. MS_LOG(EXCEPTION) << "Get acltdt handle failed";
  73. return false;
  74. }
  75. ms_context_ptr->acl_tdt_print = std::thread(TensorPrint(acl_handle));
  76. #endif
  77. MS_LOG(INFO) << "Get the acltdt handle successful, tsd reference = "
  78. << ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
  79. return true;
  80. }
  81. bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
  82. if (ms_context_ptr == nullptr) {
  83. MS_LOG(EXCEPTION) << "ms_context_prt is nullptr";
  84. }
  85. if (ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
  86. return true;
  87. }
  88. ms_context_ptr->decrease_param<uint32_t>(MS_CTX_TSD_REF);
  89. if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) == 0) {
  90. ms_context_ptr->set_param<uint32_t>(MS_CTX_TSD_REF, 0);
  91. #ifdef ENABLE_TDTQUE
  92. ms_context_ptr->DestroyAclTdtChannelHandle();
  93. py::gil_scoped_release gil_release;
  94. try {
  95. if (ms_context_ptr->acl_tdt_print.joinable()) {
  96. MS_LOG(INFO) << "join acl tdt host receive process";
  97. ms_context_ptr->acl_tdt_print.join();
  98. }
  99. } catch (const std::exception &e) {
  100. MS_LOG(ERROR) << "tdt thread join failed: " << e.what();
  101. }
  102. #endif
  103. uint32_t device_id = ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  104. auto ret = rtDeviceReset(device_id);
  105. if (ret != RT_ERROR_NONE) {
  106. MS_LOG(EXCEPTION) << "Device " << device_id << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
  107. return false;
  108. }
  109. ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
  110. MS_LOG(INFO) << "Call rtDeviceReset, destroy and close tsd successful, ret[" << static_cast<int>(ret) << "]";
  111. } else {
  112. MS_LOG(DEBUG) << "Acltdt Dataset client is used, no need to close, tsd reference = "
  113. << ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) << ".";
  114. }
  115. return true;
  116. }
  117. #else
  118. bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) { return true; }
  119. bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool) { return true; }
  120. #endif
  121. void SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options) {
  122. auto env_disable_reuse_memory = common::GetEnv("DISABLE_REUSE_MEMORY");
  123. if (!env_disable_reuse_memory.empty()) {
  124. (*ge_options)["ge.exec.disableReuseMemory"] = env_disable_reuse_memory;
  125. } else {
  126. (*ge_options)["ge.exec.disableReuseMemory"] = "0";
  127. MS_LOG(WARNING) << "DISABLE_REUSE_MEMORY is not set in ENV. Now set to default value 0";
  128. }
  129. }
  130. void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std::string, std::string> *ge_options) {
  131. if (ms_context_ptr == nullptr) {
  132. MS_LOG(EXCEPTION) << "nullptr";
  133. }
  134. #ifdef ENABLE_GE
  135. (*ge_options)["device_id"] = "0";
  136. (*ge_options)["ge.exec.enableDump"] = std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP));
  137. (*ge_options)["ge.exec.dumpPath"] = ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH);
  138. (*ge_options)["ge.exec.dumpMode"] = "output";
  139. MS_LOG(INFO) << "The enable dump state is " << std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_DUMP))
  140. << " and save dump path is " << ms_context_ptr->get_param<std::string>(MS_CTX_SAVE_DUMP_PATH) << ".";
  141. (*ge_options)["ge.exec.profilingMode"] = std::to_string(ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_PROFILING));
  142. if (ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_PROFILING)) {
  143. (*ge_options)["ge.exec.profilingOptions"] = ms_context_ptr->get_param<std::string>(MS_CTX_PROFILING_OPTIONS);
  144. }
  145. (*ge_options)["rank_table_file"] = "";
  146. auto env_ddk_version = common::GetEnv("DDK_VERSION");
  147. if (!env_ddk_version.empty()) {
  148. (*ge_options)["ge.DDK_version"] = env_ddk_version;
  149. } else {
  150. (*ge_options)["ge.DDK_version"] = "1.60.T17.B830";
  151. }
  152. (*ge_options)["graphType"] = "1";
  153. if (ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE) != "0") {
  154. (*ge_options)["ge.graphMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_GRAPH_MEMORY_MAX_SIZE);
  155. }
  156. if (ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE) != "0") {
  157. (*ge_options)["ge.variableMemoryMaxSize"] = ms_context_ptr->get_param<std::string>(MS_CTX_VARIABLE_MEMORY_MAX_SIZE);
  158. }
  159. #if ENABLE_TRAIN == 1
  160. (*ge_options)["ge.graphRunMode"] = "1";
  161. #endif
  162. SetDisableReuseMemoryFlag(ge_options);
  163. SetHcclOptions(ms_context_ptr, ge_options);
  164. auto env_job_id = common::GetEnv("JOB_ID");
  165. if (!env_job_id.empty()) {
  166. (*ge_options)["ge.exec.jobId"] = env_job_id;
  167. } else {
  168. (*ge_options)["ge.exec.jobId"] = "0";
  169. MS_LOG(WARNING) << "JOB_ID is not set in ENV. Now set to default value 0";
  170. }
  171. auto env_fe_flag = common::GetEnv("FE_FLAG");
  172. if (!env_fe_flag.empty()) {
  173. (*ge_options)["ge.feFlag"] = env_fe_flag;
  174. MS_LOG(INFO) << "Use FE, make sure fe lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
  175. }
  176. auto env_aicpu_flag = common::GetEnv("AICPU_FLAG");
  177. if (!env_aicpu_flag.empty()) {
  178. (*ge_options)["ge.aicpuFlag"] = env_aicpu_flag;
  179. MS_LOG(INFO) << "Use AICPU, make sure aicpu lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
  180. }
  181. auto proto_lib_path = common::GetEnv("OPTION_PROTO_LIB_PATH");
  182. if (!proto_lib_path.empty()) {
  183. char real_path[PATH_MAX] = {0};
  184. if (realpath(proto_lib_path.c_str(), real_path)) {
  185. proto_lib_path = real_path;
  186. (*ge_options)["ge.opsProtoLibPath"] = proto_lib_path;
  187. }
  188. } else {
  189. MS_LOG(WARNING) << "Set proto lib path failed!";
  190. }
  191. // Enable auto mixed precision according to the context options
  192. if (ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_AUTO_MIXED_PRECISION)) {
  193. (*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision";
  194. } else {
  195. (*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16";
  196. }
  197. // Disable the global variable acc, only enable it while adding training graph in pipeline
  198. (*ge_options)["ge.exec.variable_acc"] = "0";
  199. #endif
  200. }
  201. void SetHcclOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std::string, std::string> *ge_options) {
  202. if (ms_context_ptr == nullptr) {
  203. MS_LOG(EXCEPTION) << "nullptr";
  204. }
  205. auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
  206. auto env_rank_id = common::GetEnv("RANK_ID");
  207. auto env_device_id = std::to_string(ms_context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID));
  208. if (!(env_table_file.empty() || env_rank_id.empty())) {
  209. MS_LOG(INFO) << "Initialize Ge for distribute parameter";
  210. MS_LOG(INFO) << "Use hccl, make sure hccl lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
  211. auto env_hccl_flag = common::GetEnv("HCCL_FLAG");
  212. if (!env_hccl_flag.empty()) {
  213. (*ge_options)["ge.exec.hcclFlag"] = env_hccl_flag;
  214. }
  215. (*ge_options)["ge.exec.isUseHcom"] = "1";
  216. (*ge_options)["ge.exec.deviceId"] = env_device_id;
  217. (*ge_options)["ge.exec.rankId"] = env_rank_id;
  218. (*ge_options)["ge.exec.podName"] = env_rank_id;
  219. (*ge_options)["ge.exec.rankTableFile"] = env_table_file;
  220. (*ge_options)["ge.graphRunMode"] = "1";
  221. } else {
  222. // device id is still needed for non-distribute case
  223. (*ge_options)["ge.exec.deviceId"] = env_device_id;
  224. MS_LOG(INFO) << "No hccl mode. "
  225. "If use hccl, make sure [RANK_TABLE_FILE,RANK_ID,DEVICE_ID,DEPLOY_MODE] all be set in ENV.";
  226. }
  227. auto env_deploy_mode = common::GetEnv("DEPLOY_MODE");
  228. if (!env_deploy_mode.empty()) {
  229. (*ge_options)["ge.exec.deployMode"] = env_deploy_mode;
  230. } else {
  231. (*ge_options)["ge.exec.deployMode"] = "0";
  232. MS_LOG(WARNING) << "DEPLOY_MODE is not set in ENV. Now set to default value 0";
  233. }
  234. }
  235. bool InitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
  236. if (ms_context_ptr == nullptr) {
  237. MS_LOG(EXCEPTION) << "nullptr";
  238. }
  239. #ifdef ENABLE_GE
  240. if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT)) {
  241. return true;
  242. }
  243. if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF)) {
  244. ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF);
  245. return true;
  246. }
  247. std::map<std::string, std::string> ge_options;
  248. GetGeOptions(ms_context_ptr, &ge_options);
  249. {
  250. // Release GIL before calling into (potentially long-running) C++ code
  251. py::gil_scoped_release release;
  252. if (ge::GEInitialize(ge_options) != ge::GRAPH_SUCCESS) {
  253. MS_LOG(EXCEPTION) << "Initialize GE failed!";
  254. }
  255. }
  256. ms_context_ptr->increase_param<uint32_t>(MS_CTX_GE_REF);
  257. MS_LOG(INFO) << "Init ge successful, ge reference = " << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
  258. #endif
  259. return true;
  260. }
  261. bool PynativeInitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
  262. if (ms_context_ptr == nullptr) {
  263. MS_LOG(EXCEPTION) << "nullptr";
  264. }
  265. if (ms_context_ptr->get_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT) ||
  266. ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) || ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF)) {
  267. return true;
  268. }
  269. (void)OpenTsd(ms_context_ptr);
  270. (void)InitGe(ms_context_ptr);
  271. ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, true);
  272. return true;
  273. }
  274. bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
  275. if (ms_context_ptr == nullptr) {
  276. MS_LOG(EXCEPTION) << "nullptr";
  277. }
  278. #ifdef ENABLE_GE
  279. if (ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
  280. return true;
  281. }
  282. ms_context_ptr->decrease_param<uint32_t>(MS_CTX_GE_REF);
  283. if (force || ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) == 0) {
  284. ms_context_ptr->set_param<uint32_t>(MS_CTX_GE_REF, 0);
  285. try {
  286. DfGraphManager::GetInstance().DeleteGraphRunner();
  287. DfGraphManager::GetInstance().DeleteGeSession();
  288. } catch (const std::exception &e) {
  289. MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Error: " << e.what();
  290. } catch (...) {
  291. std::string exName(abi::__cxa_current_exception_type()->name());
  292. MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Exception name: " << exName;
  293. }
  294. if (ge::GEFinalize() != ge::GRAPH_SUCCESS) {
  295. MS_LOG(WARNING) << "Finalize GE failed!";
  296. }
  297. ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
  298. } else {
  299. MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = "
  300. << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
  301. }
  302. #endif
  303. return true;
  304. }
  305. bool IsTsdOpened(const std::shared_ptr<MsContext> &ms_context_ptr) {
  306. if (ms_context_ptr == nullptr) {
  307. MS_LOG(EXCEPTION) << "nullptr";
  308. }
  309. return ms_context_ptr->get_param<uint32_t>(MS_CTX_TSD_REF) > 0;
  310. }
  311. bool IsGeInited(const std::shared_ptr<MsContext> &ms_context_ptr) {
  312. if (ms_context_ptr == nullptr) {
  313. MS_LOG(EXCEPTION) << "nullptr";
  314. }
  315. return ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) > 0;
  316. }
  317. // Register for device type.
  318. struct DeviceTypeSetRegister {
  319. DeviceTypeSetRegister() {
  320. MsContext::device_type_seter([](std::shared_ptr<MsContext> &device_type_seter) {
  321. #ifdef ENABLE_GE
  322. device_type_seter.reset(new (std::nothrow) MsContext("ge", kAscendDevice));
  323. #elif defined(ENABLE_D)
  324. device_type_seter.reset(new (std::nothrow) MsContext("ms", kAscendDevice));
  325. #elif defined(ENABLE_GPU)
  326. device_type_seter.reset(new (std::nothrow) MsContext("ms", kGPUDevice));
  327. #else
  328. device_type_seter.reset(new (std::nothrow) MsContext("vm", kCPUDevice));
  329. #endif
  330. });
  331. }
  332. ~DeviceTypeSetRegister() = default;
  333. } device_type_set_regsiter;
  334. } // namespace context
  335. } // namespace mindspore