You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

init.cc 20 kB

5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <pybind11/operators.h>
  17. #include "backend/kernel_compiler/oplib/oplib.h"
  18. #include "backend/kernel_compiler/oplib/oploader.h"
  19. #include "pipeline/jit/pipeline.h"
  20. #include "frontend/operator/composite/composite.h"
  21. #include "pipeline/pynative/pynative_execute.h"
  22. #include "utils/symbolic.h"
  23. #include "pybind_api/api_register.h"
  24. #include "pipeline/jit/parse/python_adapter.h"
  25. #include "utils/summary/event_writer.h"
  26. #include "utils/config_manager.h"
  27. #include "utils/mpi/mpi_config.h"
  28. #include "frontend/parallel/context.h"
  29. #include "frontend/parallel/costmodel_context.h"
  30. #ifdef ENABLE_GPU_COLLECTIVE
  31. #include "runtime/device/gpu/distribution/collective_init.h"
  32. #else
  33. #include "runtime/device/gpu/distribution/collective_fake_init.h"
  34. #endif
  35. #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
  36. #include "ps/util.h"
  37. #endif
  38. #include "ps/ps_context.h"
  39. #include "pybind_api/gil_scoped_long_running.h"
  40. namespace py = pybind11;
  41. using EnvInstance = mindspore::EnvInstance;
  42. using ExecutorPy = mindspore::pipeline::ExecutorPy;
  43. using Pipeline = mindspore::pipeline::Pipeline;
  44. using PrimitivePy = mindspore::PrimitivePy;
  45. using MetaFuncGraph = mindspore::MetaFuncGraph;
  46. using EventWriter = mindspore::summary::EventWriter;
  47. using OpLib = mindspore::kernel::OpLib;
  48. using OpInfoLoaderPy = mindspore::kernel::OpInfoLoaderPy;
  49. using ParallelContext = mindspore::parallel::ParallelContext;
  50. using CostModelContext = mindspore::parallel::CostModelContext;
  51. using mindspore::MsCtxParam;
  52. using PSContext = mindspore::ps::PSContext;
  53. // Interface with python
  54. PYBIND11_MODULE(_c_expression, m) {
  55. m.doc() = "MindSpore c plugin";
  56. auto fns = mindspore::PybindDefineRegister::AllFuncs();
  57. for (auto &item : fns) {
  58. item.second(&m);
  59. }
  60. mindspore::ScopedLongRunning::SetHook(std::make_unique<mindspore::GilScopedLongRunningHook>());
  61. // Class Pipeline interface
  62. (void)py::class_<ExecutorPy, std::shared_ptr<ExecutorPy>>(m, "Executor_")
  63. .def_static("get_instance", &ExecutorPy::GetInstance, "Executor get_instance.")
  64. .def("__call__", &ExecutorPy::Run, py::arg("args"), py::arg("phase") = py::str(""), "Executor run function.")
  65. .def("del_net_res", &ExecutorPy::DelNetRes, py::arg("network_id") = py::str(""), "Delete network resource.")
  66. .def("get_func_graph", &ExecutorPy::GetFuncGraph, py::arg("phase") = py::str(""), "Get graph pointer.")
  67. .def("get_func_graph_proto", &ExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""),
  68. py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.")
  69. .def("compile", &ExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""),
  70. py::arg("use_vm") = py::bool_(false), "Compile obj by executor.")
  71. .def("updata_param_node_default_input", &ExecutorPy::UpdataParamNodeDefaultInput, py::arg("phase"),
  72. py::arg("params"), "Fetch the inputs of Conv or Matmul for quant export.")
  73. .def("get_parameter_layout", &ExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"),
  74. "Get Parameter Tensor Layout Dictionary.")
  75. .def("get_strategy", &ExecutorPy::GetCNodeStrategy, py::arg("phase") = py::str("train"),
  76. "Get CNode Strategy Dictionary.")
  77. .def("get_allreduce_fusion", &ExecutorPy::GetAllreduceFusion, py::arg("phase") = py::str("train"),
  78. "Get Allreduce Fusion Dictionary.")
  79. .def("fetch_info_for_quant_export", &ExecutorPy::FetchInfoForQuantExport, py::arg("phase") = py::str("train"),
  80. "Fetch the inputs of Conv or Matmul for quant export.")
  81. .def("build_data_graph", &ExecutorPy::BuildGraph, py::arg("build_params"), py::arg("phase") = py::str("train"),
  82. py::arg("broadcast_params") = py::dict(), "Build data graph.")
  83. .def("has_compiled", &ExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "get if cell compiled.")
  84. .def("run_init_graph", &ExecutorPy::RunInitGraph, "Run init Graph.");
  85. (void)py::class_<EnvInstance, std::shared_ptr<EnvInstance>>(m, "EnvInstance_").def(py::init());
  86. (void)m.def("generate_key", &mindspore::pipeline::GenerateKey, "Generate the function graph key.");
  87. (void)m.def("real_run_op", &mindspore::pynative::RunOp, "Run op pynatively.");
  88. (void)m.def("reset_op_id", &mindspore::pipeline::ResetOpId, "Reset Operator Id");
  89. (void)m.def("init_hccl", &mindspore::pipeline::InitHccl, "Init Hccl");
  90. (void)m.def("finalize_hccl", &mindspore::pipeline::FinalizeHccl, "Finalize Hccl");
  91. (void)m.def("verify_inputs_signature", &mindspore::pipeline::VerifyInputSignature, "Verify input signature.");
  92. (void)m.def("init_exec_dataset", &mindspore::pipeline::InitExecDataset, py::arg("queue_name"), py::arg("size"),
  93. py::arg("batch_size"), py::arg("types"), py::arg("shapes"), py::arg("input_indexs"),
  94. py::arg("phase") = py::str("dataset"), py::arg("need_run") = py::bool_(true), "Init and exec dataset.");
  95. (void)m.def("_set_dataset_mode_config", &mindspore::ConfigManager::SetDatasetModeConfig, "API for set dataset mode.");
  96. (void)m.def("init_backend", &mindspore::pipeline::InitBackend, "Init Backend.");
  97. (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
  98. (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
  99. .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
  100. .def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.")
  101. .def("set_enable_mpi", &mindspore::MpiConfig::set_enable_mpi, "Set whether to enable mpi.");
  102. (void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
  103. .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
  104. .def("get_device_num", &ParallelContext::device_num, "Get device num.")
  105. .def("set_device_num", &ParallelContext::set_device_num, "Set device num.")
  106. .def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.")
  107. .def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")
  108. .def("set_global_rank", &ParallelContext::set_global_rank, "Set global rank.")
  109. .def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.")
  110. .def("get_gradients_mean", &ParallelContext::gradients_mean, "Get mirror mean.")
  111. .def("set_gradients_mean", &ParallelContext::set_gradients_mean, "Set mirror mean.")
  112. .def("get_gradient_fp32_sync", &ParallelContext::gradient_fp32_sync, "Get cast before mirror.")
  113. .def("set_gradient_fp32_sync", &ParallelContext::set_gradient_fp32_sync, "Set cast before mirror.")
  114. .def("get_loss_repeated_mean", &ParallelContext::loss_repeated_mean, "Get loss repeated mean.")
  115. .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.")
  116. .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.")
  117. .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.")
  118. .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.")
  119. .def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.")
  120. .def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices,
  121. "Set all reduce fusion split indices.")
  122. .def("get_all_reduce_fusion_split_indices", &ParallelContext::GetAllReduceFusionSplitIndices,
  123. "Get all reduce fusion split indices.")
  124. .def("set_all_reduce_fusion_split_sizes", &ParallelContext::SetAllReduceFusionSplitSizes,
  125. "Set all reduce fusion split sizes.")
  126. .def("get_all_reduce_fusion_split_sizes", &ParallelContext::GetAllReduceFusionSplitSizes,
  127. "Get all reduce fusion split sizes.")
  128. .def("set_enable_all_reduce_fusion", &ParallelContext::set_enable_all_reduce_fusion,
  129. "Set enable/disable all reduce fusion.")
  130. .def("get_enable_all_reduce_fusion", &ParallelContext::enable_all_reduce_fusion,
  131. "Get enable/disable all reduce fusion.")
  132. .def("get_parameter_broadcast", &ParallelContext::parameter_broadcast, "Get parameter broadcast.")
  133. .def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set,
  134. "Get parameter broadcast is set.")
  135. .def("set_parameter_broadcast", &ParallelContext::set_parameter_broadcast, "Set parameter broadcast.")
  136. .def("set_strategy_ckpt_load_file", &ParallelContext::set_strategy_ckpt_load_file,
  137. "Set strategy checkpoint load file.")
  138. .def("set_strategy_ckpt_save_file", &ParallelContext::set_strategy_ckpt_save_file,
  139. "Set strategy checkpoint save file.")
  140. .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
  141. .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
  142. .def("set_pipeline_stage_split_num", &ParallelContext::set_pipeline_stage_split_num,
  143. "Set pipeline stage split num.")
  144. .def("get_pipeline_stage_split_num", &ParallelContext::pipeline_stage_split_num, "Get pipeline stage split num.")
  145. .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
  146. .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
  147. .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,
  148. "Set enable/disable parallel optimizer.")
  149. .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
  150. "Get enable/disable parallel optimizer.")
  151. .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
  152. (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
  153. .def_static("get_instance", &CostModelContext::GetInstance, "Get cost_model context instance.")
  154. .def("set_device_memory_capacity", &CostModelContext::set_device_memory_capacity,
  155. "Set the capacity of device memory.")
  156. .def("get_device_memory_capacity", &CostModelContext::device_memory_capacity, "Get the capacity of device memory.")
  157. .def("set_costmodel_alpha", &CostModelContext::set_costmodel_alpha,
  158. "Set the parameter cost_model_alpha of the DP algorithm.")
  159. .def("get_costmodel_alpha", &CostModelContext::costmodel_alpha,
  160. "Get the parameter cost_model_alpha of the DP algorithm.")
  161. .def("set_costmodel_beta", &CostModelContext::set_costmodel_beta,
  162. "Set the parameter cost_model_beta of the DP algorithm.")
  163. .def("get_costmodel_beta", &CostModelContext::costmodel_beta,
  164. "Get the parameter cost_model_beta of the DP algorithm.")
  165. .def("set_costmodel_gamma", &CostModelContext::set_costmodel_gamma,
  166. "Set the parameter cost_model_gamma of the DP algorithm")
  167. .def("get_costmodel_gamma", &CostModelContext::costmodel_gamma,
  168. "Get the parameter cost_model_gamma of the DP algorithm.")
  169. .def("set_costmodel_communi_threshold", &CostModelContext::set_costmodel_communi_threshold,
  170. "Set the parameter cost_model_communi_threshold of the DP algorithm.")
  171. .def("get_costmodel_communi_threshold", &CostModelContext::costmodel_communi_threshold,
  172. "Get the parameter cost_model_communi_threshold of the DP algorithm.")
  173. .def("set_costmodel_communi_const", &CostModelContext::set_costmodel_communi_const,
  174. "Set the parameter cost_model_communi_const of the DP algorithm.")
  175. .def("get_costmodel_communi_const", &CostModelContext::costmodel_communi_const,
  176. "Get the parameter cost_model_communi_const of the DP algorithm.")
  177. .def("set_costmodel_communi_bias", &CostModelContext::set_costmodel_communi_bias,
  178. "Set the parameter cost_model_communi_bias of the DP algorithm.")
  179. .def("get_costmodel_communi_bias", &CostModelContext::costmodel_communi_bias,
  180. "Get the parameter cost_model_communi_bias of the DP algorithm.")
  181. .def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.")
  182. .def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.")
  183. .def("set_run_phase", &CostModelContext::set_run_phase, "Set the flag run_phase.")
  184. .def("get_run_phase", &CostModelContext::run_phase, "Get the flag run_phase.")
  185. .def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm,
  186. "Set the parameter gradient AllReduce fusion algorithm.")
  187. .def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm,
  188. "Get the parameter gradient AllReduce fusion algorithm.")
  189. .def("set_costmodel_allreduce_fusion_times", &CostModelContext::set_costmodel_allreduce_fusion_times,
  190. "Set the parameter gradient AllReduce times.")
  191. .def("get_costmodel_allreduce_fusion_times", &CostModelContext::costmodel_allreduce_fusion_times,
  192. "Get the parameter gradient AllReduce times.")
  193. .def("set_costmodel_allreduce_fusion_tail_percent", &CostModelContext::set_costmodel_allreduce_fusion_tail_percent,
  194. "Set the parameter gradient AllReduce fusion tail percent.")
  195. .def("get_costmodel_allreduce_fusion_tail_percent", &CostModelContext::costmodel_allreduce_fusion_tail_percent,
  196. "Get the parameter gradient AllReduce fusion tail percent.")
  197. .def("set_costmodel_allreduce_fusion_tail_time", &CostModelContext::set_costmodel_allreduce_fusion_tail_time,
  198. "Set the parameter gradient AllReduce fusion tail time.")
  199. .def("get_costmodel_allreduce_fusion_tail_time", &CostModelContext::costmodel_allreduce_fusion_tail_time,
  200. "Get the parameter gradient AllReduce fusion tail time.")
  201. .def("set_costmodel_allreduce_fusion_allreduce_inherent_time",
  202. &CostModelContext::set_costmodel_allreduce_fusion_allreduce_inherent_time,
  203. "Set the parameter gradient AllReduce fusion allreduce inherent time.")
  204. .def("get_costmodel_allreduce_fusion_allreduce_inherent_time",
  205. &CostModelContext::costmodel_allreduce_fusion_allreduce_inherent_time,
  206. "Get the parameter gradient AllReduce fusion allreduce inherent time.")
  207. .def("set_costmodel_allreduce_fusion_allreduce_bandwidth",
  208. &CostModelContext::set_costmodel_allreduce_fusion_allreduce_bandwidth,
  209. "Set the parameter gradient AllReduce fusion allreduce bandwidth.")
  210. .def("get_costmodel_allreduce_fusion_allreduce_bandwidth",
  211. &CostModelContext::costmodel_allreduce_fusion_allreduce_bandwidth,
  212. "Get the parameter gradient AllReduce fusion allreduce bandwidth.")
  213. .def("set_costmodel_allreduce_fusion_computation_time_parameter",
  214. &CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter,
  215. "Set the parameter gradient AllReduce fusion computation time parameter.")
  216. .def("get_costmodel_allreduce_fusion_computation_time_parameter",
  217. &CostModelContext::costmodel_allreduce_fusion_computation_time_parameter,
  218. "Get the parameter gradient AllReduce fusion computation time parameter.")
  219. .def("set_tensor_slice_align_enable", &CostModelContext::set_tensor_slice_alignment_enable,
  220. "Set the parameter tensor_slice_align_enable in strategy generation.")
  221. .def("get_tensor_slice_align_enable", &CostModelContext::tensor_slice_alignment_enable,
  222. "Get the parameter tensor_slice_align_enable in strategy generation.")
  223. .def("set_tensor_slice_align_size", &CostModelContext::set_tensor_slice_alignment_size,
  224. "Set the parameter tensor_slice_size in strategy generation.")
  225. .def("get_tensor_slice_align_size", &CostModelContext::tensor_slice_alignment_size,
  226. "Get the parameter tensor_slice_size in strategy generation.")
  227. .def("set_fully_use_devices", &CostModelContext::set_fully_use_device,
  228. "Set the parameter fully_use_devices in the DP algorithm.")
  229. .def("get_fully_use_devices", &CostModelContext::fully_use_device,
  230. "Get the parameter fully_use_devices in the DP algorithm.")
  231. .def("set_elementwise_op_strategy_follow", &CostModelContext::set_elementwise_stra_follow,
  232. "Set the parameter elementwise_op_strategy_follow in the DP algorithm.")
  233. .def("get_elementwise_op_strategy_follow", &CostModelContext::elementwise_stra_follow,
  234. "Get the parameter elementwise_op_strategy_follow in the DP algorithm.")
  235. .def("reset_cost_model", &CostModelContext::ResetCostModel, "Reset the CostModelContext.")
  236. .def("reset_algo_parameters", &CostModelContext::ResetAlgoParameters, "Reset the AlgoParameters.");
  237. (void)py::module::import("atexit").attr("register")(py::cpp_function{[&]() -> void {
  238. // only in case that c++ calling python interface, ClearResAtexit should be called.
  239. if (mindspore::parse::python_adapter::IsPythonEnv()) {
  240. mindspore::pipeline::ClearResAtexit();
  241. #ifdef ENABLE_MINDDATA
  242. py::module iterators = py::module::import("mindspore.dataset.engine.iterators");
  243. (void)iterators.attr("_cleanup")();
  244. #endif
  245. }
  246. }});
  247. (void)py::class_<EventWriter, std::shared_ptr<EventWriter>>(m, "EventWriter_")
  248. .def(py::init<const std::string &>())
  249. .def("GetFileName", &EventWriter::GetFileName, "Get the file name.")
  250. .def("Open", &EventWriter::Open, "Open the write file.")
  251. .def("Write", &EventWriter::Write, "Write the serialize event.")
  252. .def("EventCount", &EventWriter::GetWriteEventCount, "Write event count.")
  253. .def("Flush", &EventWriter::Flush, "Flush the event.")
  254. .def("Close", &EventWriter::Close, "Close the write.")
  255. .def("Shut", &EventWriter::Shut, "Final close the write.");
  256. (void)py::class_<OpLib, std::shared_ptr<OpLib>>(m, "Oplib")
  257. .def(py::init())
  258. .def_static("reg_op", &OpLib::RegOp, "Register op info.");
  259. #ifdef ENABLE_GPU_COLLECTIVE
  260. (void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::InitCollective,
  261. "Init gpu collective communication mode.");
  262. (void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::FinalizeCollective,
  263. "Finalize gpu collective communication mode.");
  264. #else
  265. (void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::InitCollective,
  266. "Init gpu collective communication mode.");
  267. (void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::FinalizeCollective,
  268. "Finalize gpu collective communication mode.");
  269. #endif
  270. (void)py::class_<PSContext, std::shared_ptr<PSContext>>(m, "PSContext")
  271. .def_static("get_instance", &PSContext::instance, "Get PS context instance.")
  272. .def("set_ps_enable", &PSContext::SetPSEnable, "Set PS mode enabled or disabled.")
  273. .def("is_ps_enabled", &PSContext::is_ps_enabled, "Get PS mode enable-disable status.")
  274. .def("reset", &PSContext::Reset, "Reset PS context attributes.")
  275. .def("is_role_worker", &PSContext::is_role_worker, "Get whether the role of this process is Worker.")
  276. .def("is_role_pserver", &PSContext::is_role_pserver, "Get whether the role of this process is PServer.")
  277. .def("is_role_sched", &PSContext::is_role_sched, "Get whether the role of this process is Scheduler.")
  278. .def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.");
  279. (void)py::class_<OpInfoLoaderPy, std::shared_ptr<OpInfoLoaderPy>>(m, "OpInfoLoaderPy")
  280. .def(py::init())
  281. .def("get_all_ops_info", &OpInfoLoaderPy::GetAllOpsInfo, "get all ops info.");
  282. }