You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

init.cc 33 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <pybind11/operators.h>
  17. #include "backend/kernel_compiler/oplib/oplib.h"
  18. #include "pipeline/jit/pipeline.h"
  19. #include "frontend/operator/composite/composite.h"
  20. #include "pipeline/pynative/pynative_execute.h"
  21. #include "utils/symbolic.h"
  22. #include "pybind_api/api_register.h"
  23. #include "pipeline/jit/parse/python_adapter.h"
  24. #ifndef ENABLE_SECURITY
  25. #include "utils/summary/event_writer.h"
  26. #endif
  27. #include "utils/config_manager.h"
  28. #include "utils/mpi/mpi_config.h"
  29. #include "utils/ms_utils.h"
  30. #include "frontend/parallel/context.h"
  31. #include "frontend/parallel/costmodel_context.h"
  32. #include "frontend/optimizer/ad/dfunctor.h"
  33. #ifdef ENABLE_GPU_COLLECTIVE
  34. #include "runtime/device/gpu/distribution/collective_init.h"
  35. #else
  36. #include "runtime/device/gpu/distribution/collective_fake_init.h"
  37. #endif
  38. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  39. #include "ps/util.h"
  40. #endif
  41. #include "ps/ps_context.h"
  42. #include "pybind_api/gil_scoped_long_running.h"
  43. namespace py = pybind11;
  44. using EnvInstance = mindspore::EnvInstance;
  45. using GraphExecutorPy = mindspore::pipeline::GraphExecutorPy;
  46. using Pipeline = mindspore::pipeline::Pipeline;
  47. using PrimitivePy = mindspore::PrimitivePy;
  48. using MetaFuncGraph = mindspore::MetaFuncGraph;
  49. #ifndef ENABLE_SECURITY
  50. using EventWriter = mindspore::summary::EventWriter;
  51. #endif // ENABLE_SECURITY
  52. using OpLib = mindspore::kernel::OpLib;
  53. using ParallelContext = mindspore::parallel::ParallelContext;
  54. using CostModelContext = mindspore::parallel::CostModelContext;
  55. using mindspore::MsCtxParam;
  56. using PSContext = mindspore::ps::PSContext;
  57. // Interface with python
  58. PYBIND11_MODULE(_c_expression, m) {
  59. // The OMP_NUM_THREADS has no effect when set in backend, so set it here in advance.
  60. mindspore::common::SetOMPThreadNum();
  61. m.doc() = "MindSpore c plugin";
  62. auto fns = mindspore::PybindDefineRegister::AllFuncs();
  63. for (auto &item : fns) {
  64. item.second(&m);
  65. }
  66. mindspore::ScopedLongRunning::SetHook(std::make_unique<mindspore::GilScopedLongRunningHook>());
  67. // Class Pipeline interface
  68. (void)py::class_<GraphExecutorPy, std::shared_ptr<GraphExecutorPy>>(m, "GraphExecutor_")
  69. .def_static("get_instance", &GraphExecutorPy::GetInstance, "Executor get_instance.")
  70. .def("__call__", &GraphExecutorPy::Run, py::arg("args"), py::arg("phase") = py::str(""), "Executor run function.")
  71. .def("del_net_res", &GraphExecutorPy::DelNetRes, py::arg("network_id") = py::set(), "Delete network resource.")
  72. .def("get_func_graph", &GraphExecutorPy::GetFuncGraph, py::arg("phase") = py::str(""), "Get graph pointer.")
  73. .def("get_func_graph_proto", &GraphExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""),
  74. py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.")
  75. .def("compile", &GraphExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""),
  76. py::arg("use_vm") = py::bool_(false), "Compile obj by executor.")
  77. .def("updata_param_node_default_input", &GraphExecutorPy::UpdataParamNodeDefaultInput, py::arg("phase"),
  78. py::arg("params"), "Fetch the inputs of Conv or Matmul for quant export.")
  79. .def("get_parameter_layout", &GraphExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"),
  80. "Get Parameter Tensor Layout Dictionary.")
  81. .def("get_parallel_parameter_name_list", &GraphExecutorPy::GetParallelParameterNameList,
  82. py::arg("phase") = py::str("train"), "Get Parallel Parameter Name List.")
  83. .def("get_strategy", &GraphExecutorPy::GetCNodeStrategy, py::arg("phase") = py::str("train"),
  84. "Get CNode Strategy Dictionary.")
  85. .def("get_num_parallel_ops", &GraphExecutorPy::GetNumOpsInfo, py::arg("phase") = py::str("train"),
  86. "Get the number of parallel operators.")
  87. .def("get_allreduce_fusion", &GraphExecutorPy::GetAllreduceFusion, py::arg("phase") = py::str("train"),
  88. "Get Allreduce Fusion Dictionary.")
  89. .def("fetch_info_for_quant_export", &GraphExecutorPy::FetchInfoForQuantExport, py::arg("phase") = py::str("train"),
  90. "Fetch the inputs of Conv or Matmul for quant export.")
  91. .def("build_data_graph", &GraphExecutorPy::BuildGraph, py::arg("build_params"), py::arg("phase") = py::str("train"),
  92. py::arg("broadcast_params") = py::dict(), "Build data graph.")
  93. .def("has_compiled", &GraphExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "Get if cell compiled.")
  94. .def("run_init_graph", &GraphExecutorPy::RunInitGraph, "Run init Graph.")
  95. .def("set_py_exe_path", &GraphExecutorPy::PyExePath, py::arg("py_exe_path") = py::str(""),
  96. "Set python executable path.")
  97. .def("set_kernel_build_server_dir", &GraphExecutorPy::KernelBuildServerDir,
  98. py::arg("kernel_build_server_dir") = py::str(""), "Set kernel build server directory path.")
  99. .def("set_queue_name", &GraphExecutorPy::set_queue_name, py::arg("queue_name") = py::str(""),
  100. "Set queue name for the graph loaded from compile cache.")
  101. .def("set_enable_tuple_broaden", &GraphExecutorPy::set_enable_tuple_broaden,
  102. py::arg("enable_tuple_broaden") = py::bool_(false), "Set tuple broaden enable.")
  103. .def("set_compile_cache_dep_files", &GraphExecutorPy::set_compile_cache_dep_files,
  104. py::arg("compile_cache_dep_files") = py::list(), "Set the compilation cache dependent files.")
  105. .def("set_weights_values", &GraphExecutorPy::set_weights_values, py::arg("weights") = py::dict(),
  106. "Set values of weights.")
  107. .def("get_optimize_graph_proto", &GraphExecutorPy::GetOptimizeGraphProto, py::arg("phase") = py::str(""),
  108. "Get the optimize graph proto string.")
  109. .def("set_jit_config", &GraphExecutorPy::SetJitConfig, py::arg("jit_config") = py::dict(), "Set the jit config.");
  110. (void)py::class_<EnvInstance, std::shared_ptr<EnvInstance>>(m, "EnvInstance_").def(py::init());
  111. (void)m.def("generate_arguments_key", &mindspore::pipeline::GenerateArgumentsKey, "Generate unique key of argument.");
  112. (void)m.def("real_run_op", &mindspore::pynative::RealRunOp, "Run op pynatively.");
  113. (void)m.def("reset_op_id", &mindspore::pipeline::ResetOpId, "Reset Operator Id");
  114. (void)m.def("init_hccl", &mindspore::pipeline::InitHccl, "Init Hccl");
  115. (void)m.def("finalize_hccl", &mindspore::pipeline::FinalizeHccl, "Finalize Hccl");
  116. (void)m.def("get_hccl_rank_id", &mindspore::pipeline::GetHcclRankId, "Get Hccl Rank Id");
  117. (void)m.def("get_hccl_rank_size", &mindspore::pipeline::GetHcclRankSize, "Get Hccl Rank Size");
  118. (void)m.def("verify_inputs_signature", &mindspore::pipeline::VerifyInputSignature, "Verify input signature.");
  119. (void)m.def("init_exec_dataset", &mindspore::pipeline::InitExecDataset, py::arg("queue_name"), py::arg("size"),
  120. py::arg("batch_size"), py::arg("types"), py::arg("shapes"), py::arg("input_indexs"),
  121. py::arg("phase") = py::str("dataset"), py::arg("need_run") = py::bool_(true), "Init and exec dataset.");
  122. (void)m.def("_set_dataset_mode_config", &mindspore::ConfigManager::SetDatasetModeConfig, "API for set dataset mode.");
  123. (void)m.def("init_pipeline", &mindspore::pipeline::InitPipeline, "Init Pipeline.");
  124. (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
  125. (void)m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), py::arg("dec_key") = nullptr,
  126. py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"), "Load model as Graph.");
  127. (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
  128. .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
  129. .def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.")
  130. .def("set_enable_mpi", &mindspore::MpiConfig::set_enable_mpi, "Set whether to enable mpi.");
  131. (void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
  132. .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
  133. .def("get_device_num", &ParallelContext::device_num, "Get device num.")
  134. .def("set_hccl_test_avaible", &ParallelContext::set_hccl_test_available, "Set hccl test available.")
  135. .def("set_device_num", &ParallelContext::set_device_num, "Set device num.")
  136. .def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.")
  137. .def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")
  138. .def("set_global_rank", &ParallelContext::set_global_rank, "Set global rank.")
  139. .def("get_grad_accumulation_shard", &ParallelContext::grad_accumulation_shard, "Get grad_accumulation_shard.")
  140. .def("set_grad_accumulation_shard", &ParallelContext::set_grad_accumulation_shard, "Set grad_accumulation_shard.")
  141. .def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.")
  142. .def("get_gradients_mean", &ParallelContext::gradients_mean, "Get mirror mean.")
  143. .def("set_gradients_mean", &ParallelContext::set_gradients_mean, "Set mirror mean.")
  144. .def("get_gradient_fp32_sync", &ParallelContext::gradient_fp32_sync, "Get cast before mirror.")
  145. .def("set_gradient_fp32_sync", &ParallelContext::set_gradient_fp32_sync, "Set cast before mirror.")
  146. .def("get_loss_repeated_mean", &ParallelContext::loss_repeated_mean, "Get loss repeated mean.")
  147. .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.")
  148. .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.")
  149. .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.")
  150. .def("get_grad_accumulation_step", &ParallelContext::grad_accumulation_step, "Get grad accumulation step.")
  151. .def("set_grad_accumulation_step", &ParallelContext::set_grad_accumulation_step, "Set grad accumulation step.")
  152. .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.")
  153. .def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.")
  154. .def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices,
  155. "Set all reduce fusion split indices.")
  156. .def("get_all_reduce_fusion_split_indices", &ParallelContext::GetAllReduceFusionSplitIndices,
  157. "Get all reduce fusion split indices.")
  158. .def("set_all_reduce_fusion_split_sizes", &ParallelContext::SetAllReduceFusionSplitSizes,
  159. "Set all reduce fusion split sizes.")
  160. .def("get_all_reduce_fusion_split_sizes", &ParallelContext::GetAllReduceFusionSplitSizes,
  161. "Get all reduce fusion split sizes.")
  162. .def("set_enable_all_reduce_fusion", &ParallelContext::set_enable_all_reduce_fusion,
  163. "Set enable/disable all reduce fusion.")
  164. .def("get_enable_all_reduce_fusion", &ParallelContext::enable_all_reduce_fusion,
  165. "Get enable/disable all reduce fusion.")
  166. .def("get_parameter_broadcast", &ParallelContext::parameter_broadcast, "Get parameter broadcast.")
  167. .def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set,
  168. "Get parameter broadcast is set.")
  169. .def("set_parameter_broadcast", &ParallelContext::set_parameter_broadcast, "Set parameter broadcast.")
  170. .def("set_strategy_ckpt_load_file", &ParallelContext::set_strategy_ckpt_load_file,
  171. "Set strategy checkpoint load file.")
  172. .def("set_strategy_ckpt_save_file", &ParallelContext::set_strategy_ckpt_save_file,
  173. "Set strategy checkpoint save file.")
  174. .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
  175. .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
  176. .def("set_group_ckpt_save_file", &ParallelContext::set_group_ckpt_save_file, "Set group checkpoint save file.")
  177. .def("set_pipeline_stage_split_num", &ParallelContext::set_pipeline_stage_split_num,
  178. "Set pipeline stage split num.")
  179. .def("get_pipeline_stage_split_num", &ParallelContext::pipeline_stage_split_num, "Get pipeline stage split num.")
  180. .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
  181. .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
  182. .def("set_dataset_strategy", &ParallelContext::set_dataset_strategy, "Set dataset sharding strategy.")
  183. .def("get_dataset_strategy", &ParallelContext::dataset_strategy, "Get dataset sharding strategy.")
  184. .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,
  185. "Set enable/disable parallel optimizer.")
  186. .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
  187. "Get enable/disable parallel optimizer.")
  188. .def("set_communi_parallel_mode", &ParallelContext::set_communi_parallel_mode, "Set communication parallel mode.")
  189. .def("get_communi_parallel_mode", &ParallelContext::communi_parallel_mode, "Get communication parallel mode.")
  190. .def("set_optimizer_weight_shard_size", &ParallelContext::set_optimizer_weight_shard_size,
  191. "Set opt shard group size when not fully use parallel optimizer.")
  192. .def("get_optimizer_weight_shard_size", &ParallelContext::optimizer_weight_shard_size,
  193. "Get opt shard group size when not fully use parallel optimizer.")
  194. .def("set_optimizer_weight_shard_aggregated_save", &ParallelContext::set_optimizer_weight_shard_aggregated_save,
  195. "Set whether to integrated save weight shard when enable parallel optimizer.")
  196. .def("get_optimizer_weight_shard_aggregated_save", &ParallelContext::optimizer_weight_shard_aggregated_save,
  197. "Get whether to integrated save weight shard when enable parallel optimizer.")
  198. .def("set_enable_alltoall", &ParallelContext::set_enable_all2all, "Set the enabling AllToAll value.")
  199. .def("get_enable_alltoall", &ParallelContext::enable_all2all, "Get the enabling AllToAll value.")
  200. .def("set_sharding_propagation", &ParallelContext::set_sharding_propagation,
  201. "Set sharding strategy propagation value.")
  202. .def("get_sharding_propagation", &ParallelContext::sharding_propagation, "Get sharding strategy propagation value.")
  203. .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
  204. (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
  205. .def_static("get_instance", &CostModelContext::GetInstance, "Get cost_model context instance.")
  206. .def("set_device_memory_capacity", &CostModelContext::set_device_memory_capacity,
  207. "Set the capacity of device memory.")
  208. .def("get_device_memory_capacity", &CostModelContext::device_memory_capacity, "Get the capacity of device memory.")
  209. .def("set_costmodel_alpha", &CostModelContext::set_costmodel_alpha,
  210. "Set the parameter cost_model_alpha of the DP algorithm.")
  211. .def("get_costmodel_alpha", &CostModelContext::costmodel_alpha,
  212. "Get the parameter cost_model_alpha of the DP algorithm.")
  213. .def("set_costmodel_beta", &CostModelContext::set_costmodel_beta,
  214. "Set the parameter cost_model_beta of the DP algorithm.")
  215. .def("get_costmodel_beta", &CostModelContext::costmodel_beta,
  216. "Get the parameter cost_model_beta of the DP algorithm.")
  217. .def("set_costmodel_gamma", &CostModelContext::set_costmodel_gamma,
  218. "Set the parameter cost_model_gamma of the DP algorithm")
  219. .def("get_costmodel_gamma", &CostModelContext::costmodel_gamma,
  220. "Get the parameter cost_model_gamma of the DP algorithm.")
  221. .def("set_costmodel_communi_threshold", &CostModelContext::set_costmodel_communi_threshold,
  222. "Set the parameter cost_model_communi_threshold of the DP algorithm.")
  223. .def("get_costmodel_communi_threshold", &CostModelContext::costmodel_communi_threshold,
  224. "Get the parameter cost_model_communi_threshold of the DP algorithm.")
  225. .def("set_costmodel_communi_const", &CostModelContext::set_costmodel_communi_const,
  226. "Set the parameter cost_model_communi_const of the DP algorithm.")
  227. .def("get_costmodel_communi_const", &CostModelContext::costmodel_communi_const,
  228. "Get the parameter cost_model_communi_const of the DP algorithm.")
  229. .def("set_costmodel_communi_bias", &CostModelContext::set_costmodel_communi_bias,
  230. "Set the parameter cost_model_communi_bias of the DP algorithm.")
  231. .def("get_costmodel_communi_bias", &CostModelContext::costmodel_communi_bias,
  232. "Get the parameter cost_model_communi_bias of the DP algorithm.")
  233. .def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.")
  234. .def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.")
  235. .def("set_run_phase", &CostModelContext::set_run_phase, "Set the flag run_phase.")
  236. .def("get_run_phase", &CostModelContext::run_phase, "Get the flag run_phase.")
  237. .def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm,
  238. "Set the parameter gradient AllReduce fusion algorithm.")
  239. .def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm,
  240. "Get the parameter gradient AllReduce fusion algorithm.")
  241. .def("set_costmodel_allreduce_fusion_times", &CostModelContext::set_costmodel_allreduce_fusion_times,
  242. "Set the parameter gradient AllReduce times.")
  243. .def("get_costmodel_allreduce_fusion_times", &CostModelContext::costmodel_allreduce_fusion_times,
  244. "Get the parameter gradient AllReduce times.")
  245. .def("set_costmodel_allreduce_fusion_tail_percent", &CostModelContext::set_costmodel_allreduce_fusion_tail_percent,
  246. "Set the parameter gradient AllReduce fusion tail percent.")
  247. .def("get_costmodel_allreduce_fusion_tail_percent", &CostModelContext::costmodel_allreduce_fusion_tail_percent,
  248. "Get the parameter gradient AllReduce fusion tail percent.")
  249. .def("set_costmodel_allreduce_fusion_tail_time", &CostModelContext::set_costmodel_allreduce_fusion_tail_time,
  250. "Set the parameter gradient AllReduce fusion tail time.")
  251. .def("get_costmodel_allreduce_fusion_tail_time", &CostModelContext::costmodel_allreduce_fusion_tail_time,
  252. "Get the parameter gradient AllReduce fusion tail time.")
  253. .def("set_costmodel_allreduce_fusion_allreduce_inherent_time",
  254. &CostModelContext::set_costmodel_allreduce_fusion_allreduce_inherent_time,
  255. "Set the parameter gradient AllReduce fusion allreduce inherent time.")
  256. .def("get_costmodel_allreduce_fusion_allreduce_inherent_time",
  257. &CostModelContext::costmodel_allreduce_fusion_allreduce_inherent_time,
  258. "Get the parameter gradient AllReduce fusion allreduce inherent time.")
  259. .def("set_costmodel_allreduce_fusion_allreduce_bandwidth",
  260. &CostModelContext::set_costmodel_allreduce_fusion_allreduce_bandwidth,
  261. "Set the parameter gradient AllReduce fusion allreduce bandwidth.")
  262. .def("get_costmodel_allreduce_fusion_allreduce_bandwidth",
  263. &CostModelContext::costmodel_allreduce_fusion_allreduce_bandwidth,
  264. "Get the parameter gradient AllReduce fusion allreduce bandwidth.")
  265. .def("set_costmodel_allreduce_fusion_computation_time_parameter",
  266. &CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter,
  267. "Set the parameter gradient AllReduce fusion computation time parameter.")
  268. .def("get_costmodel_allreduce_fusion_computation_time_parameter",
  269. &CostModelContext::costmodel_allreduce_fusion_computation_time_parameter,
  270. "Get the parameter gradient AllReduce fusion computation time parameter.")
  271. .def("set_tensor_slice_align_enable", &CostModelContext::set_tensor_slice_alignment_enable,
  272. "Set the parameter tensor_slice_align_enable in strategy generation.")
  273. .def("get_tensor_slice_align_enable", &CostModelContext::tensor_slice_alignment_enable,
  274. "Get the parameter tensor_slice_align_enable in strategy generation.")
  275. .def("set_tensor_slice_align_size", &CostModelContext::set_tensor_slice_alignment_size,
  276. "Set the parameter tensor_slice_size in strategy generation.")
  277. .def("get_tensor_slice_align_size", &CostModelContext::tensor_slice_alignment_size,
  278. "Get the parameter tensor_slice_size in strategy generation.")
  279. .def("set_fully_use_devices", &CostModelContext::set_fully_use_device,
  280. "Set the parameter fully_use_devices in the DP algorithm.")
  281. .def("get_fully_use_devices", &CostModelContext::fully_use_device,
  282. "Get the parameter fully_use_devices in the DP algorithm.")
  283. .def("set_elementwise_op_strategy_follow", &CostModelContext::set_elementwise_stra_follow,
  284. "Set the parameter elementwise_op_strategy_follow in the DP algorithm.")
  285. .def("get_elementwise_op_strategy_follow", &CostModelContext::elementwise_stra_follow,
  286. "Get the parameter elementwise_op_strategy_follow in the DP algorithm.")
  287. .def("set_dp_algo_enable_approxi", &CostModelContext::set_dp_algo_enable_approxi,
  288. "Set the flag whether enabling approximation in the DP algorithm.")
  289. .def("get_dp_algo_enable_approxi", &CostModelContext::dp_algo_enable_approxi,
  290. "Get the flag whether enabling approximation in the DP algorithm.")
  291. .def("set_dp_algo_approxi_epsilon", &CostModelContext::set_dp_algo_approxi_epsilon,
  292. "Set the epsilon which is used in the approximation of DP algorithm.")
  293. .def("get_dp_algo_approxi_epsilon", &CostModelContext::dp_algo_approxi_epsilon,
  294. "Get the epsilon which is used in the approximation of DP algorithm.")
  295. .def("set_dp_algo_single_loop", &CostModelContext::set_dp_algo_single_loop,
  296. "Set the flag of generating a single suite of OperatorInfos in for-loop.")
  297. .def("get_dp_algo_single_loop", &CostModelContext::dp_algo_single_loop,
  298. "Get the flag of whether or not generating a single suite of OperatorInfos in for-loop.")
  299. .def("reset_cost_model", &CostModelContext::ResetCostModel, "Reset the CostModelContext.")
  300. .def("reset_algo_parameters", &CostModelContext::ResetAlgoParameters, "Reset the AlgoParameters.");
  301. (void)py::module::import("atexit").attr("register")(py::cpp_function{[&]() -> void {
  302. #ifdef ENABLE_MINDDATA
  303. MS_LOG(INFO) << "Start releasing dataset handles...";
  304. py::module iterators = py::module::import("mindspore.dataset.engine.iterators");
  305. (void)iterators.attr("_cleanup")();
  306. MS_LOG(INFO) << "End release dataset handles.";
  307. #endif
  308. // only in case that c++ calling python interface, ClearResAtexit should be called.
  309. if (mindspore::parse::python_adapter::IsPythonEnv()) {
  310. mindspore::pipeline::ClearResAtexit();
  311. }
  312. }});
  313. #ifndef ENABLE_SECURITY
  314. (void)py::class_<EventWriter, std::shared_ptr<EventWriter>>(m, "EventWriter_")
  315. .def(py::init<const std::string &>())
  316. .def("GetFileName", &EventWriter::GetFileName, "Get the file name.")
  317. .def("Open", &EventWriter::Open, "Open the write file.")
  318. .def("Write", &EventWriter::Write, "Write the serialize event.")
  319. .def("EventCount", &EventWriter::GetWriteEventCount, "Write event count.")
  320. .def("Flush", &EventWriter::Flush, "Flush the event.")
  321. .def("Close", &EventWriter::Close, "Close the write.")
  322. .def("Shut", &EventWriter::Shut, "Final close the write.");
  323. #endif // ENABLE_SECURITY
  324. (void)py::class_<OpLib, std::shared_ptr<OpLib>>(m, "Oplib")
  325. .def(py::init())
  326. .def_static("reg_op", &OpLib::RegOp, "Register op info.");
  327. #ifdef ENABLE_GPU_COLLECTIVE
  328. (void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::InitCollective,
  329. "Init gpu collective communication mode.");
  330. (void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveInitializer::FinalizeCollective,
  331. "Finalize gpu collective communication mode.");
  332. #else
  333. (void)m.def("init_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::InitCollective,
  334. "Init gpu collective communication mode.");
  335. (void)m.def("finalize_gpu_collective", &mindspore::device::gpu::CollectiveFakeInitializer::FinalizeCollective,
  336. "Finalize gpu collective communication mode.");
  337. #endif
  338. (void)py::class_<PSContext, std::shared_ptr<PSContext>>(m, "PSContext")
  339. .def_static("get_instance", &PSContext::instance, "Get PS context instance.")
  340. .def("set_ps_enable", &PSContext::SetPSEnable, "Set PS mode enabled or disabled.")
  341. .def("is_ps_mode", &PSContext::is_ps_mode, "Get PS mode enable-disable status.")
  342. .def("reset", &PSContext::Reset, "Reset PS context attributes.")
  343. .def("is_worker", &PSContext::is_worker, "Get whether the role of this process is Worker.")
  344. .def("is_server", &PSContext::is_server, "Get whether the role of this process is PServer.")
  345. .def("is_scheduler", &PSContext::is_scheduler, "Get whether the role of this process is Scheduler.")
  346. .def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.")
  347. .def("insert_hash_table_size", &PSContext::InsertHashTableSize, "Insert hash table size.")
  348. .def("reinsert_hash_table_size", &PSContext::ReInsertHashTableSize,
  349. "Insert hash table size with new parameter name.")
  350. .def("insert_weight_init_info", &PSContext::InsertWeightInitInfo, "Insert embedding table initialization seed.")
  351. .def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.")
  352. .def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.")
  353. .def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.")
  354. .def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.")
  355. .def("set_server_mode", &PSContext::set_server_mode, "Set server mode.")
  356. .def("server_mode", &PSContext::server_mode, "Get server mode.")
  357. .def("set_ms_role", &PSContext::set_ms_role, "Set role for this process.")
  358. .def("ms_role", &PSContext::ms_role, "Get role for this process.")
  359. .def("set_worker_num", &PSContext::set_worker_num, "Set worker number.")
  360. .def("worker_num", &PSContext::worker_num, "Get worker number.")
  361. .def("set_server_num", &PSContext::set_server_num, "Set server number.")
  362. .def("server_num", &PSContext::server_num, "Get server number.")
  363. .def("set_scheduler_ip", &PSContext::set_scheduler_ip, "Set scheduler ip.")
  364. .def("scheduler_ip", &PSContext::scheduler_ip, "Get scheduler ip.")
  365. .def("set_scheduler_port", &PSContext::set_scheduler_port, "Set scheduler port.")
  366. .def("scheduler_port", &PSContext::scheduler_port, "Get scheduler port.")
  367. .def("set_fl_server_port", &PSContext::set_fl_server_port, "Set federated learning server port.")
  368. .def("fl_server_port", &PSContext::fl_server_port, "Get federated learning server port.")
  369. .def("set_fl_client_enable", &PSContext::set_fl_client_enable, "Set federated learning client.")
  370. .def("fl_client_enable", &PSContext::fl_client_enable, "Get federated learning client.")
  371. .def("set_start_fl_job_threshold", &PSContext::set_start_fl_job_threshold,
  372. "Set threshold count for startFLJob round.")
  373. .def("start_fl_job_threshold", &PSContext::start_fl_job_threshold, "Get threshold count for startFLJob round.")
  374. .def("set_start_fl_job_time_window", &PSContext::set_start_fl_job_time_window,
  375. "Set time window for startFLJob round.")
  376. .def("start_fl_job_time_window", &PSContext::start_fl_job_time_window, "Get time window for startFLJob round.")
  377. .def("set_update_model_ratio", &PSContext::set_update_model_ratio,
  378. "Set threshold count ratio for updateModel round.")
  379. .def("update_model_ratio", &PSContext::update_model_ratio, "Get threshold count ratio for updateModel round.")
  380. .def("set_update_model_time_window", &PSContext::set_update_model_time_window,
  381. "Set time window for updateModel round.")
  382. .def("update_model_time_window", &PSContext::update_model_time_window, "Get time window for updateModel round.")
  383. .def("set_share_secrets_ratio", &PSContext::set_share_secrets_ratio,
  384. "Set threshold count ratio for share secrets round.")
  385. .def("share_secrets_ratio", &PSContext::share_secrets_ratio, "Get threshold count ratio for share secrets round.")
  386. .def("set_cipher_time_window", &PSContext::set_cipher_time_window, "Set time window for each cipher round.")
  387. .def("set_reconstruct_secrets_threshold", &PSContext::set_reconstruct_secrets_threshold,
  388. "Set threshold count for reconstruct secrets round.")
  389. .def("reconstruct_secrets_threshold", &PSContext::reconstruct_secrets_threshold,
  390. "Get threshold count for reconstruct secrets round.")
  391. .def("set_fl_name", &PSContext::set_fl_name, "Set federated learning name.")
  392. .def("fl_name", &PSContext::fl_name, "Get federated learning name.")
  393. .def("set_fl_iteration_num", &PSContext::set_fl_iteration_num, "Set federated learning iteration number.")
  394. .def("fl_iteration_num", &PSContext::fl_iteration_num, "Get federated learning iteration number.")
  395. .def("set_client_epoch_num", &PSContext::set_client_epoch_num, "Set federated learning client epoch number.")
  396. .def("client_epoch_num", &PSContext::client_epoch_num, "Get federated learning client epoch number.")
  397. .def("set_client_batch_size", &PSContext::set_client_batch_size, "Set federated learning client batch size.")
  398. .def("client_batch_size", &PSContext::client_batch_size, "Get federated learning client batch size.")
  399. .def("set_client_learning_rate", &PSContext::set_client_learning_rate,
  400. "Set worker's standalone training step number before communicating with server.")
  401. .def("client_learning_rate", &PSContext::client_learning_rate,
  402. "Get worker's standalone training step number before communicating with server.")
  403. .def("set_worker_step_num_per_iteration", &PSContext::set_worker_step_num_per_iteration,
  404. "Set federated learning client learning rate.")
  405. .def("worker_step_num_per_iteration", &PSContext::worker_step_num_per_iteration,
  406. "Get federated learning client learning rate.")
  407. .def("set_scheduler_manage_port", &PSContext::set_scheduler_manage_port,
  408. "Set scheduler manage port used to scale out/in.")
  409. .def("scheduler_manage_port", &PSContext::scheduler_manage_port, "Get scheduler manage port used to scale out/in.")
  410. .def("set_enable_ssl", &PSContext::set_enable_ssl, "Set PS SSL mode enabled or disabled.")
  411. .def("enable_ssl", &PSContext::enable_ssl, "Get PS SSL mode enabled or disabled.")
  412. .def("set_client_password", &PSContext::set_client_password, "Set the client password to decode the p12 file.")
  413. .def("client_password", &PSContext::client_password, "Get the client password to decode the p12 file.")
  414. .def("set_server_password", &PSContext::set_server_password, "Set the server password to decode the p12 file.")
  415. .def("server_password", &PSContext::server_password, "Get the server password to decode the p12 file.")
  416. .def("set_config_file_path", &PSContext::set_config_file_path,
  417. "Set configuration files required by the communication layer.")
  418. .def("config_file_path", &PSContext::config_file_path,
  419. "Get configuration files required by the communication layer.")
  420. .def("set_dp_eps", &PSContext::set_dp_eps, "Set dp epsilon for federated learning secure aggregation.")
  421. .def("set_dp_delta", &PSContext::set_dp_delta, "Set dp delta for federated learning secure aggregation.")
  422. .def("set_dp_norm_clip", &PSContext::set_dp_norm_clip,
  423. "Set dp norm clip for federated learning secure aggregation.")
  424. .def("set_encrypt_type", &PSContext::set_encrypt_type,
  425. "Set encrypt type for federated learning secure aggregation.");
  426. (void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
  427. (void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
  428. (void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted");
  429. #ifndef _WIN32
  430. (void)m.def("_export_bprop_mindir", &mindspore::ad::KPrim::ExportBpropMindir,
  431. "Export the backpropagation function to mindir file.");
  432. #endif
  433. }