You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_ps_context.py 11 kB

5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Context for parameter server training mode"""
  16. import os
  17. from mindspore._checkparam import Validator
  18. from mindspore._c_expression import PSContext
  19. _ps_context = None
  20. def ps_context():
  21. """
  22. Get the global _ps_context, if it is not created, create a new one.
  23. Returns:
  24. _ps_context, the global parameter server training mode context.
  25. """
  26. global _ps_context
  27. if _ps_context is None:
  28. _ps_context = PSContext.get_instance()
  29. return _ps_context
  30. _set_ps_context_func_map = {
  31. "server_mode": ps_context().set_server_mode,
  32. "ms_role": ps_context().set_ms_role,
  33. "enable_ps": ps_context().set_ps_enable,
  34. "enable_fl": ps_context().set_ps_enable,
  35. "worker_num": ps_context().set_worker_num,
  36. "server_num": ps_context().set_server_num,
  37. "scheduler_ip": ps_context().set_scheduler_ip,
  38. "scheduler_port": ps_context().set_scheduler_port,
  39. "fl_server_port": ps_context().set_fl_server_port,
  40. "enable_fl_client": ps_context().set_fl_client_enable,
  41. "start_fl_job_threshold": ps_context().set_start_fl_job_threshold,
  42. "start_fl_job_time_window": ps_context().set_start_fl_job_time_window,
  43. "update_model_ratio": ps_context().set_update_model_ratio,
  44. "update_model_time_window": ps_context().set_update_model_time_window,
  45. "share_secrets_ratio": ps_context().set_share_secrets_ratio,
  46. "cipher_time_window": ps_context().set_cipher_time_window,
  47. "reconstruct_secrets_threshold": ps_context().set_reconstruct_secrets_threshold,
  48. "fl_name": ps_context().set_fl_name,
  49. "fl_iteration_num": ps_context().set_fl_iteration_num,
  50. "client_epoch_num": ps_context().set_client_epoch_num,
  51. "client_batch_size": ps_context().set_client_batch_size,
  52. "client_learning_rate": ps_context().set_client_learning_rate,
  53. "worker_step_num_per_iteration": ps_context().set_worker_step_num_per_iteration,
  54. "root_first_ca_path": ps_context().set_root_first_ca_path,
  55. "root_second_ca_path": ps_context().set_root_second_ca_path,
  56. "pki_verify": ps_context().set_pki_verify,
  57. "equip_crl_path": ps_context().set_equip_crl_path,
  58. "replay_attack_time_diff": ps_context().set_replay_attack_time_diff,
  59. "enable_ssl": ps_context().set_enable_ssl,
  60. "client_password": ps_context().set_client_password,
  61. "server_password": ps_context().set_server_password,
  62. "scheduler_manage_port": ps_context().set_scheduler_manage_port,
  63. "config_file_path": ps_context().set_config_file_path,
  64. "dp_eps": ps_context().set_dp_eps,
  65. "dp_delta": ps_context().set_dp_delta,
  66. "dp_norm_clip": ps_context().set_dp_norm_clip,
  67. "encrypt_type": ps_context().set_encrypt_type,
  68. "http_url_prefix": ps_context().set_http_url_prefix
  69. }
  70. _get_ps_context_func_map = {
  71. "server_mode": ps_context().server_mode,
  72. "ms_role": ps_context().ms_role,
  73. "enable_ps": ps_context().is_ps_mode,
  74. "enable_fl": ps_context().is_ps_mode,
  75. "worker_num": ps_context().worker_num,
  76. "server_num": ps_context().server_num,
  77. "scheduler_ip": ps_context().scheduler_ip,
  78. "scheduler_port": ps_context().scheduler_port,
  79. "fl_server_port": ps_context().fl_server_port,
  80. "enable_fl_client": ps_context().fl_client_enable,
  81. "start_fl_job_threshold": ps_context().start_fl_job_threshold,
  82. "start_fl_job_time_window": ps_context().start_fl_job_time_window,
  83. "update_model_ratio": ps_context().update_model_ratio,
  84. "update_model_time_window": ps_context().update_model_time_window,
  85. "share_secrets_ratio": ps_context().share_secrets_ratio,
  86. "cipher_time_window": ps_context().cipher_time_window,
  87. "reconstruct_secrets_threshold": ps_context().reconstruct_secrets_threshold,
  88. "fl_name": ps_context().fl_name,
  89. "fl_iteration_num": ps_context().fl_iteration_num,
  90. "client_epoch_num": ps_context().client_epoch_num,
  91. "client_batch_size": ps_context().client_batch_size,
  92. "client_learning_rate": ps_context().client_learning_rate,
  93. "worker_step_num_per_iteration": ps_context().worker_step_num_per_iteration,
  94. "dp_eps": ps_context().dp_eps,
  95. "dp_delta": ps_context().dp_delta,
  96. "dp_norm_clip": ps_context().dp_norm_clip,
  97. "encrypt_type": ps_context().encrypt_type,
  98. "root_first_ca_path": ps_context().root_first_ca_path,
  99. "root_second_ca_path": ps_context().root_second_ca_path,
  100. "pki_verify": ps_context().pki_verify,
  101. "equip_crl_path": ps_context().equip_crl_path,
  102. "replay_attack_time_diff": ps_context().replay_attack_time_diff,
  103. "enable_ssl": ps_context().enable_ssl,
  104. "client_password": ps_context().client_password,
  105. "server_password": ps_context().server_password,
  106. "scheduler_manage_port": ps_context().scheduler_manage_port,
  107. "config_file_path": ps_context().config_file_path,
  108. "http_url_prefix": ps_context().http_url_prefix
  109. }
  110. _check_positive_int_keys = ["server_num", "scheduler_port", "fl_server_port",
  111. "start_fl_job_threshold", "start_fl_job_time_window", "update_model_time_window",
  112. "fl_iteration_num", "client_epoch_num", "client_batch_size", "cipher_time_window",
  113. "reconstruct_secrets_threshold"]
  114. _check_non_negative_int_keys = ["worker_num"]
  115. _check_positive_float_keys = ["update_model_ratio", "client_learning_rate"]
  116. _check_port_keys = ["scheduler_port", "fl_server_port"]
  117. def _get_ps_mode_rank():
  118. ps_rank = ps_context().ps_rank_id()
  119. if ps_rank == -1:
  120. raise RuntimeError("The parameter server mode training is not enabled yet.")
  121. return ps_rank
  122. def _set_ps_context(**kwargs):
  123. """
  124. Set parameter server training mode context.
  125. Note:
  126. Some other environment variables should also be set for parameter server training mode.
  127. These environment variables are listed below:
  128. .. code-block::
  129. MS_SERVER_NUM # Server number
  130. MS_WORKER_NUM # Worker number
  131. MS_SCHED_HOST # Scheduler IP address
  132. MS_SCHED_PORT # Scheduler port
  133. MS_ROLE # The role of this process:
  134. # MS_SCHED represents the scheduler,
  135. # MS_WORKER represents the worker,
  136. # MS_PSERVER represents the Server
  137. Args:
  138. enable_ps (bool): Whether to enable parameter server training mode.
  139. Only after enable_ps is set True, the environment variables will be effective.
  140. Default: False.
  141. config_file_path (string): Configuration file path used by recovery. Default: ''.
  142. scheduler_manage_port (int): scheduler manage port used to scale out/in. Default: 11202.
  143. enable_ssl (bool): Set PS SSL mode enabled or disabled. Default: true.
  144. client_password (str): Password to decrypt the secret key stored in the client certificate.
  145. server_password (str): Password to decrypt the secret key stored in the server certificate.
  146. Raises:
  147. ValueError: If input key is not the attribute in parameter server training mode context.
  148. Examples:
  149. >>> context.set_ps_context(enable_ps=True, enable_ssl=True, client_password='123456', server_password='123456')
  150. """
  151. for key, value in kwargs.items():
  152. if key not in _set_ps_context_func_map:
  153. raise ValueError("Set PS context keyword %s is not recognized!" % key)
  154. _check_value(key, value)
  155. set_func = _set_ps_context_func_map[key]
  156. set_func(value)
  157. def _get_ps_context(attr_key):
  158. """
  159. Get parameter server training mode context attribute value according to the key.
  160. Args:
  161. attr_key (str): The key of the attribute.
  162. Returns:
  163. Returns attribute value according to the key.
  164. Raises:
  165. ValueError: If input key is not attribute in auto parallel context.
  166. """
  167. if attr_key not in _get_ps_context_func_map:
  168. raise ValueError("Get PS context keyword %s is not recognized!" % attr_key)
  169. get_func = _get_ps_context_func_map[attr_key]
  170. value = get_func()
  171. return value
  172. def _reset_ps_context():
  173. """
  174. Reset parameter server training mode context attributes to the default values:
  175. - enable_ps: False.
  176. """
  177. ps_context().reset()
  178. def _is_role_worker():
  179. return ps_context().is_worker()
  180. def _is_role_pserver():
  181. return ps_context().is_server()
  182. def _is_role_sched():
  183. return ps_context().is_scheduler()
  184. def _insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size):
  185. ps_context().insert_hash_table_size(name, cache_vocab_size, embedding_size, vocab_size)
  186. def _reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size):
  187. ps_context().reinsert_hash_table_size(new_name, cur_name, cache_vocab_size, embedding_size)
  188. def _insert_weight_init_info(name, global_seed, op_seed):
  189. ps_context().insert_weight_init_info(name, global_seed, op_seed)
  190. def _insert_accumu_init_info(name, init_val):
  191. ps_context().insert_accumu_init_info(name, init_val)
  192. def _clone_hash_table(dest_param_name, src_param_name):
  193. ps_context().clone_hash_table(dest_param_name, src_param_name)
  194. def _set_cache_enable(cache_enable):
  195. # Environment variables are used to specify a maximum number of OpenBLAS threads:
  196. # In ubuntu(GPU) environment, numpy will use too many threads for computing,
  197. if cache_enable:
  198. os.environ['OPENBLAS_NUM_THREADS'] = '2'
  199. os.environ['GOTO_NUM_THREADS'] = '2'
  200. os.environ['OMP_NUM_THREADS'] = '2'
  201. ps_context().set_cache_enable(cache_enable)
  202. def _set_rank_id(rank_id):
  203. ps_context().set_rank_id(rank_id)
  204. def _check_value(key, value):
  205. """
  206. Validate the value for parameter server context keys.
  207. """
  208. if key in _check_positive_int_keys:
  209. Validator.check_positive_int(value, key)
  210. if key in _check_non_negative_int_keys:
  211. Validator.check_non_negative_int(value, key)
  212. if key in _check_positive_float_keys:
  213. Validator.check_positive_float(value, key)
  214. if key in _check_port_keys:
  215. if value < 1 or value > 65535:
  216. raise ValueError("The range of %s must be 1 to 65535, but got %d." % (key, value))