You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

ms_context.cc 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "utils/context/ms_context.h"
  17. #include <thread>
  18. #include <atomic>
  19. #include <fstream>
  20. #include "./common.h"
  21. #include "utils/convert_utils.h"
  22. #include "utils/tensorprint_utils.h"
  23. #ifndef NO_DLIB
  24. #include "tdt/tsd_client.h"
  25. #include "tdt/tdt_host_interface.h"
  26. #include "tdt/data_common.h"
  27. #endif
  28. #ifdef ENABLE_GE
  29. #include "transform/df_graph_manager.h"
  30. #endif
  31. #include "ir/tensor.h"
  32. namespace mindspore {
  33. #ifdef ENABLE_GE
  34. using mindspore::transform::DfGraphManager;
  35. #endif
// Flag used to ask a background thread to stop.
// NOTE(review): not referenced anywhere in this translation unit - presumably
// set/read by code elsewhere in the project; confirm before removing.
std::atomic<bool> thread_1_must_end(false);
// Lazily created process-wide context singleton (see MsContext::GetInstance).
std::shared_ptr<MsContext> MsContext::inst_context_ = nullptr;
// Mapping from the user-facing backend policy name to the internal enum value.
std::map<std::string, MsBackendPolicy> MsContext::policy_map_ = {{"ge", kMsBackendGePrior},
                                                                 {"vm", kMsBackendVmOnly},
                                                                 {"ms", kMsBackendMsPrior},
                                                                 {"ge_only", kMsBackendGeOnly},
                                                                 {"vm_prior", kMsBackendVmPrior}};
  43. MsContext::MsContext(const std::string &policy, const std::string &target) {
  44. save_graphs_flag_ = false;
  45. save_graphs_path_ = ".";
  46. save_ms_model_flag_ = false;
  47. save_ms_model_path_ = "./model.ms";
  48. enable_dump_ = false;
  49. save_dump_path_ = ".";
  50. tsd_ref_ = 0;
  51. ge_ref_ = 0;
  52. is_multi_graph_sink_ = false;
  53. is_pynative_ge_init_ = false;
  54. enable_reduce_precision_ = true;
  55. auto env_device = common::GetEnv("DEVICE_ID");
  56. if (!env_device.empty()) {
  57. device_id_ = UlongToUint(std::stoul(env_device.c_str()));
  58. } else {
  59. device_id_ = 0;
  60. }
  61. backend_policy_ = policy_map_[policy];
  62. device_target_ = target;
  63. execution_mode_ = kPynativeMode;
  64. enable_task_sink_ = true;
  65. ir_fusion_flag_ = true;
  66. enable_hccl_ = false;
  67. #ifdef ENABLE_DEBUGGER
  68. enable_mem_reuse_ = false;
  69. #else
  70. enable_mem_reuse_ = true;
  71. #endif
  72. enable_gpu_summary_ = true;
  73. precompile_only_ = false;
  74. auto_mixed_precision_flag_ = false;
  75. enable_pynative_infer_ = false;
  76. enable_pynative_hook_ = false;
  77. enable_dynamic_mem_pool_ = true;
  78. graph_memory_max_size_ = "0";
  79. variable_memory_max_size_ = "0";
  80. enable_loop_sink_ = target == kAscendDevice || target == kDavinciDevice;
  81. profiling_mode_ = false;
  82. profiling_options_ = "training_trace";
  83. check_bprop_flag_ = false;
  84. max_device_memory_ = kDefaultMaxDeviceMemory;
  85. print_file_path_ = "";
  86. enable_graph_kernel_ = false;
  87. enable_sparse_flag_ = false;
  88. }
/**
 * Return the process-wide MsContext singleton, creating it on first use.
 * The default backend/device pair is selected at compile time by build flags.
 *
 * NOTE(review): the lazy initialization is not synchronized - the first call
 * should happen before any concurrent use; also `new (std::nothrow)` may
 * return nullptr, which is not checked here. Confirm call sites guarantee
 * single-threaded first use.
 */
std::shared_ptr<MsContext> MsContext::GetInstance() {
  if (inst_context_ == nullptr) {
    MS_LOG(DEBUG) << "Create new mindspore context";
#ifdef ENABLE_GE
    inst_context_.reset(new (std::nothrow) MsContext("ge", kAscendDevice));
#elif defined(ENABLE_D)
    inst_context_.reset(new (std::nothrow) MsContext("ms", kAscendDevice));
#elif defined(ENABLE_GPU)
    inst_context_.reset(new (std::nothrow) MsContext("ms", kGPUDevice));
#else
    inst_context_.reset(new (std::nothrow) MsContext("vm", kCPUDevice));
#endif
  }
  return inst_context_;
}
  104. bool MsContext::set_backend_policy(const std::string &policy) {
  105. if (policy_map_.find(policy) == policy_map_.end()) {
  106. MS_LOG(ERROR) << "invalid backend policy name: " << policy;
  107. return false;
  108. }
  109. backend_policy_ = policy_map_[policy];
  110. MS_LOG(INFO) << "ms set context backend policy:" << policy;
  111. return true;
  112. }
  113. std::string MsContext::backend_policy() const {
  114. auto res = std::find_if(
  115. policy_map_.begin(), policy_map_.end(),
  116. [&, this](const std::pair<std::string, MsBackendPolicy> &item) { return item.second == backend_policy_; });
  117. if (res != policy_map_.end()) {
  118. return res->first;
  119. }
  120. return "unknown";
  121. }
  122. void MsContext::set_execution_mode(int execution_mode) {
  123. if (execution_mode != kGraphMode && execution_mode != kPynativeMode) {
  124. MS_LOG(EXCEPTION) << "The execution mode is invalid!";
  125. }
  126. execution_mode_ = execution_mode;
  127. }
  128. bool MsContext::set_device_target(const std::string &target) {
  129. if (kTargetSet.find(target) == kTargetSet.end()) {
  130. MS_LOG(ERROR) << "invalid device target name: " << target;
  131. return false;
  132. }
  133. if (target == kDavinciDevice) {
  134. device_target_ = kAscendDevice;
  135. } else {
  136. device_target_ = target;
  137. }
  138. MS_LOG(INFO) << "ms set context device target:" << target;
  139. return true;
  140. }
  141. bool MsContext::set_device_id(uint32_t device_id) {
  142. device_id_ = device_id;
  143. MS_LOG(INFO) << "ms set context device id:" << device_id;
  144. return true;
  145. }
  146. #ifndef NO_DLIB
/**
 * Open the TDT (tensor data transfer) dataset client for this device.
 * Reference counted: repeated calls only bump tsd_ref_. With ENABLE_TDTQUE
 * this also initializes the TDT host side and starts the background thread
 * that receives tensors for printing.
 *
 * @return true on success; failure paths raise via MS_LOG(EXCEPTION).
 */
bool MsContext::OpenTsd() {
  // The PyNative GE init path already opened tsd; treat as success.
  if (is_pynative_ge_init_) {
    return true;
  }
  // Already open: just take another reference.
  if (tsd_ref_) {
    MS_LOG(DEBUG) << "TDT Dataset client is already opened.";
    tsd_ref_++;
    return true;
  }
  unsigned int device_id;
  unsigned int rank_size = 1;
  device_id = device_id_;
  // RANK_SIZE gives the number of devices in the distributed job (default 1).
  auto rank_size_env = common::GetEnv("RANK_SIZE");
  if (rank_size_env.empty()) {
    MS_LOG(INFO) << "Should config rank size.";
    rank_size = 1;
  } else {
    // NOTE(review): std::stoi throws on a non-numeric RANK_SIZE.
    int rank_env = std::stoi(rank_size_env);
    if (rank_env <= 0) {
      MS_LOG(EXCEPTION) << "Error rank size " << rank_env << ".";
    }
    rank_size = IntToUint(rank_env);
  }
  MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << ".";
  TDT_StatusT status = tdt::TsdClient::GetInstance()->Open(device_id, rank_size);
  if (status != TDT_OK) {
    // NOTE(review): MS_LOG(EXCEPTION) throws, so the return below appears
    // unreachable - kept as defensive code.
    MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << ".";
    return false;
  }
  tsd_ref_++;
#ifdef ENABLE_TDTQUE
  // Init the TDT host side and spawn the tensor-print receiver thread.
  int32_t initStatus = tdt::TdtHostInit(device_id);
  if (initStatus != TDT_OK_CODE) {
    MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << ".";
    return false;
  }
  tdt_print_ = std::thread(TensorPrint());
#endif
  MS_LOG(INFO) << "Open and init tsd successful, tsd reference = " << tsd_ref_ << ".";
  return true;
}
/**
 * Release one reference to the TDT dataset client; when the count reaches
 * zero (or `force` is set) stop and destroy the TDT host, join the print
 * thread and close the client.
 *
 * @param force close immediately regardless of the reference count.
 * @return true on success; failure paths raise via MS_LOG(EXCEPTION).
 */
bool MsContext::CloseTsd(bool force) {
  if (tsd_ref_ == 0) {
    return true;
  }
  tsd_ref_--;
  if (force || tsd_ref_ == 0) {
    tsd_ref_ = 0;
#ifdef ENABLE_TDTQUE
    int32_t stopStatus = tdt::TdtHostStop(KNpuLog);
    if (stopStatus != TDT_OK_CODE) {
      // NOTE(review): MS_LOG(EXCEPTION) throws, making this return unreachable.
      MS_LOG(EXCEPTION) << "Stop tsd failed, status = " << stopStatus << ".";
      return false;
    }
    // Release the GIL so the print thread (which may call into Python) can
    // make progress and terminate before we join it below.
    py::gil_scoped_release gil_release;
    int32_t destroyStatus = tdt::TdtHostDestroy();
    if (destroyStatus != TDT_OK_CODE) {
      MS_LOG(EXCEPTION) << "Destroy tsd failed, status = " << destroyStatus << ".";
      return false;
    }
    try {
      if (tdt_print_.joinable()) {
        MS_LOG(INFO) << "join tdt host receive process";
        tdt_print_.join();
      }
    } catch (const std::exception &e) {
      // Best-effort join: a failed join is logged but does not abort close.
      MS_LOG(ERROR) << "tdt thread join failed: " << e.what();
    }
#endif
    TDT_StatusT status = tdt::TsdClient::GetInstance()->Close();
    if (status != TDT_OK) {
      MS_LOG(EXCEPTION) << "Close tsd failed, status = " << status << ".";
      return false;
    }
    is_pynative_ge_init_ = false;
    MS_LOG(INFO) << "Destroy and close tsd successful, status = " << status << ".";
  } else {
    MS_LOG(DEBUG) << "TDT Dataset client is used, no need to close, tsd reference = " << tsd_ref_ << ".";
  }
  return true;
}
  229. #else
// Builds without DLIB have no TDT at all; open/close trivially succeed.
bool MsContext::OpenTsd() { return true; }
bool MsContext::CloseTsd(bool) { return true; }
  232. #endif
  233. void MsContext::SetHcclOptions(std::map<std::string, std::string> *ge_options) const {
  234. auto env_table_file = common::GetEnv("RANK_TABLE_FILE");
  235. auto env_rank_id = common::GetEnv("RANK_ID");
  236. auto env_device_id = std::to_string(device_id_);
  237. if (!(env_table_file.empty() || env_rank_id.empty())) {
  238. MS_LOG(INFO) << "Initialize Ge for distribute parameter";
  239. MS_LOG(INFO) << "Use hccl, make sure hccl lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
  240. auto env_hccl_flag = common::GetEnv("HCCL_FLAG");
  241. if (!env_hccl_flag.empty()) {
  242. (*ge_options)["ge.exec.hcclFlag"] = env_hccl_flag;
  243. }
  244. (*ge_options)["ge.exec.isUseHcom"] = "1";
  245. (*ge_options)["ge.exec.deviceId"] = env_device_id;
  246. (*ge_options)["ge.exec.rankId"] = env_rank_id;
  247. (*ge_options)["ge.exec.podName"] = env_rank_id;
  248. (*ge_options)["ge.exec.rankTableFile"] = env_table_file;
  249. (*ge_options)["ge.graphRunMode"] = "1";
  250. } else {
  251. // device id is still needed for non-distribute case
  252. (*ge_options)["ge.exec.deviceId"] = env_device_id;
  253. MS_LOG(INFO) << "No hccl mode. "
  254. "If use hccl, make sure [RANK_TABLE_FILE,RANK_ID,DEVICE_ID,DEPLOY_MODE] all be set in ENV.";
  255. }
  256. auto env_deploy_mode = common::GetEnv("DEPLOY_MODE");
  257. if (!env_deploy_mode.empty()) {
  258. (*ge_options)["ge.exec.deployMode"] = env_deploy_mode;
  259. } else {
  260. (*ge_options)["ge.exec.deployMode"] = "0";
  261. MS_LOG(WARNING) << "DEPLOY_MODE is not set in ENV. Now set to default value 0";
  262. }
  263. }
/**
 * Fill `ge_options` with all GE initialization options: dump/profiling flags,
 * plugin and library paths resolved from the environment, memory limits,
 * disable-reuse and HCCL options (via the helpers), and the precision mode.
 * Compiled to a no-op unless ENABLE_GE is defined.
 *
 * @param ge_options option map to fill (must not be null).
 */
void MsContext::GetGeOptions(std::map<std::string, std::string> *ge_options) const {
#ifdef ENABLE_GE
  (*ge_options)["device_id"] = "0";
  (*ge_options)["ge.exec.enableDump"] = std::to_string(enable_dump_);
  (*ge_options)["ge.exec.dumpPath"] = save_dump_path_;
  (*ge_options)["ge.exec.dumpMode"] = "output";
  MS_LOG(INFO) << "The enable dump state is " << std::to_string(enable_dump_) << " and save dump path is "
               << save_dump_path_ << ".";
  (*ge_options)["ge.exec.profilingMode"] = std::to_string(profiling_mode_);
  if (profiling_mode_) {
    (*ge_options)["ge.exec.profilingOptions"] = profiling_options_;
  }
  // only not supported in ge
  // Resolve the TBE plugin path to an absolute path before handing it to GE.
  auto tbe_plugin_path = common::GetEnv("ME_TBE_PLUGIN_PATH");
  if (!tbe_plugin_path.empty()) {
    char real_path[PATH_MAX] = {0};
    // NOTE(review): this branch uses `nullptr ==` while the realpath calls
    // below use the truthy form - same meaning, inconsistent style.
    if (nullptr == realpath(tbe_plugin_path.c_str(), real_path)) {
      MS_LOG(ERROR) << "Ms tbe plugin Path error, " << tbe_plugin_path;
    } else {
      tbe_plugin_path = real_path;
      (*ge_options)["ge.TBE_plugin_path"] = tbe_plugin_path;
    }
  } else {
    MS_LOG(ERROR) << "Set TBE plugin path failed!";
  }
  (*ge_options)["rank_table_file"] = "";
  auto env_ddk_version = common::GetEnv("DDK_VERSION");
  if (!env_ddk_version.empty()) {
    (*ge_options)["ge.DDK_version"] = env_ddk_version;
  } else {
    // Fallback DDK version used when the environment does not provide one.
    (*ge_options)["ge.DDK_version"] = "1.60.T17.B830";
  }
  (*ge_options)["graphType"] = "1";
  // "0" means "not configured"; only forward explicitly set memory limits.
  if (graph_memory_max_size_ != "0") {
    (*ge_options)["ge.graphMemoryMaxSize"] = graph_memory_max_size_;
  }
  if (variable_memory_max_size_ != "0") {
    (*ge_options)["ge.variableMemoryMaxSize"] = variable_memory_max_size_;
  }
#if ENABLE_TRAIN == 1
  // Training builds run graphs in training mode.
  (*ge_options)["ge.graphRunMode"] = "1";
#endif
  SetDisableReuseMemoryFlag(ge_options);
  SetHcclOptions(ge_options);
  auto env_job_id = common::GetEnv("JOB_ID");
  if (!env_job_id.empty()) {
    (*ge_options)["ge.exec.jobId"] = env_job_id;
  } else {
    (*ge_options)["ge.exec.jobId"] = "0";
    MS_LOG(WARNING) << "JOB_ID is not set in ENV. Now set to default value 0";
  }
  auto env_fe_flag = common::GetEnv("FE_FLAG");
  if (!env_fe_flag.empty()) {
    (*ge_options)["ge.feFlag"] = env_fe_flag;
    MS_LOG(INFO) << "Use FE, make sure fe lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
  }
  auto env_aicpu_flag = common::GetEnv("AICPU_FLAG");
  if (!env_aicpu_flag.empty()) {
    (*ge_options)["ge.aicpuFlag"] = env_aicpu_flag;
    MS_LOG(INFO) << "Use AICPU, make sure aicpu lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH.";
  }
  // all libs are set in same env variable "OPTION_EXEC_EXTERN_PLUGIN_PATH", such as FE, HCCL, AICPU, etc
  auto load_path = common::GetEnv("OPTION_EXEC_EXTERN_PLUGIN_PATH");
  if (!load_path.empty()) {
    char real_path[PATH_MAX] = {0};
    if (realpath(load_path.c_str(), real_path)) {
      load_path = real_path;
      (*ge_options)["ge.soLoadPath"] = load_path;
    }
  } else {
    MS_LOG(ERROR) << "Set lib load path failed!";
  }
  auto proto_lib_path = common::GetEnv("OPTION_PROTO_LIB_PATH");
  if (!proto_lib_path.empty()) {
    char real_path[PATH_MAX] = {0};
    if (realpath(proto_lib_path.c_str(), real_path)) {
      proto_lib_path = real_path;
      (*ge_options)["ge.opsProtoLibPath"] = proto_lib_path;
    }
  } else {
    MS_LOG(ERROR) << "Set proto lib path failed!";
  }
  // Enable auto mixed precision according to the context options
  if (auto_mixed_precision_flag_) {
    (*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision";
  } else {
    (*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16";
  }
  // Disable the global variable acc, only enable it whlie adding training graph in pipeline
  (*ge_options)["ge.exec.variable_acc"] = "0";
#endif
}
  356. void MsContext::SetDisableReuseMemoryFlag(std::map<std::string, std::string> *ge_options) const {
  357. auto env_disable_reuse_memory = common::GetEnv("DISABLE_REUSE_MEMORY");
  358. if (!env_disable_reuse_memory.empty()) {
  359. (*ge_options)["ge.exec.disableReuseMemory"] = env_disable_reuse_memory;
  360. } else {
  361. (*ge_options)["ge.exec.disableReuseMemory"] = "0";
  362. MS_LOG(WARNING) << "DISABLE_REUSE_MEMORY is not set in ENV. Now set to default value 0";
  363. }
  364. }
/**
 * Initialize the GE runtime (reference counted). Compiled to a trivial
 * success unless ENABLE_GE is defined.
 *
 * @return always true; a failing GEInitialize raises via MS_LOG(EXCEPTION).
 */
bool MsContext::InitGe() {
#ifdef ENABLE_GE
  // The PyNative path already initialized GE.
  if (is_pynative_ge_init_) {
    return true;
  }
  // Already initialized: just take another reference.
  if (ge_ref_) {
    ge_ref_++;
    return true;
  }
  std::map<std::string, std::string> ge_options;
  GetGeOptions(&ge_options);
  {
    // Release GIL before calling into (potentially long-running) C++ code
    py::gil_scoped_release release;
    if (ge::GEInitialize(ge_options) != ge::GRAPH_SUCCESS) {
      MS_LOG(EXCEPTION) << "Initialize GE failed!";
    }
  }
  ge_ref_++;
  MS_LOG(INFO) << "Init ge successful, ge reference = " << ge_ref_ << ".";
#endif
  return true;
}
  388. bool MsContext::FinalizeGe(bool force) {
  389. #ifdef ENABLE_GE
  390. if (ge_ref_ == 0) {
  391. return true;
  392. }
  393. ge_ref_--;
  394. if (force || ge_ref_ == 0) {
  395. ge_ref_ = 0;
  396. try {
  397. DfGraphManager::GetInstance().DeleteGraphRunner();
  398. DfGraphManager::GetInstance().DeleteGeSession();
  399. } catch (const std::exception &e) {
  400. MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Error: " << e.what();
  401. } catch (...) {
  402. std::string exName(abi::__cxa_current_exception_type()->name());
  403. MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Exception name: " << exName;
  404. }
  405. if (ge::GEFinalize() != ge::GRAPH_SUCCESS) {
  406. MS_LOG(WARNING) << "Finalize GE failed!";
  407. }
  408. is_pynative_ge_init_ = false;
  409. } else {
  410. MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = " << ge_ref_ << ".";
  411. }
  412. #endif
  413. return true;
  414. }
  415. bool MsContext::PynativeInitGe() {
  416. if (is_pynative_ge_init_ || ge_ref_ || tsd_ref_) {
  417. return true;
  418. }
  419. (void)OpenTsd();
  420. (void)InitGe();
  421. is_pynative_ge_init_ = true;
  422. return true;
  423. }
  424. bool MsContext::IsTsdOpened() {
  425. if (tsd_ref_ > 0) {
  426. return true;
  427. }
  428. return false;
  429. }
  430. bool MsContext::IsGeInited() {
  431. if (ge_ref_ > 0) {
  432. return true;
  433. }
  434. return false;
  435. }
  436. } // namespace mindspore