You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

_ps_context.py 3.7 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Context for parameter server training mode"""
  16. from mindspore._c_expression import PSContext
  17. _ps_context = None
  18. def ps_context():
  19. """
  20. Get the global _ps_context, if it is not created, create a new one.
  21. Returns:
  22. _ps_context, the global parameter server training mode context.
  23. """
  24. global _ps_context
  25. if _ps_context is None:
  26. _ps_context = PSContext.get_instance()
  27. return _ps_context
  28. _set_ps_context_func_map = {
  29. "enable_ps": ps_context().set_ps_enable
  30. }
  31. _get_ps_context_func_map = {
  32. "enable_ps": ps_context().is_ps_enabled
  33. }
  34. def _get_ps_mode_rank():
  35. ps_rank = ps_context().ps_rank_id()
  36. if ps_rank == -1:
  37. raise RuntimeError("The parameter server mode training is not enabled yet.")
  38. return ps_rank
  39. def _set_ps_context(**kwargs):
  40. """
  41. Set parameter server training mode context.
  42. Note:
  43. Some other environment variables should also be set for parameter server training mode.
  44. These environment variables are listed below:
  45. .. code-block::
  46. MS_SERVER_NUM # Server number
  47. MS_WORKER_NUM # Worker number
  48. MS_SCHED_HOST # Scheduler IP address
  49. MS_SCHED_PORT # Scheduler port
  50. MS_ROLE # The role of this process:
  51. # MS_SCHED represents the scheduler,
  52. # MS_WORKER represents the worker,
  53. # MS_PSERVER represents the Server
  54. Args:
  55. enable_ps (bool): Whether to enable parameter server training mode.
  56. Only after enable_ps is set True, the environment variables will be effective.
  57. Default: False.
  58. Raises:
  59. ValueError: If input key is not the attribute in parameter server training mode context.
  60. Examples:
  61. >>> context.set_ps_context(enable_ps=True)
  62. """
  63. for key, value in kwargs.items():
  64. if key not in _set_ps_context_func_map:
  65. raise ValueError("Set PS context keyword %s is not recognized!" % key)
  66. set_func = _set_ps_context_func_map[key]
  67. set_func(value)
  68. def _get_ps_context(attr_key):
  69. """
  70. Get parameter server training mode context attribute value according to the key.
  71. Args:
  72. attr_key (str): The key of the attribute.
  73. Returns:
  74. Returns attribute value according to the key.
  75. Raises:
  76. ValueError: If input key is not attribute in auto parallel context.
  77. """
  78. if attr_key not in _get_ps_context_func_map:
  79. raise ValueError("Get PS context keyword %s is not recognized!" % attr_key)
  80. get_func = _get_ps_context_func_map[attr_key]
  81. value = get_func()
  82. return value
  83. def _reset_ps_context():
  84. """
  85. Reset parameter server training mode context attributes to the default values:
  86. - enable_ps: False.
  87. """
  88. ps_context().reset()
  89. def _is_role_worker():
  90. return ps_context().is_role_worker()
  91. def _is_role_pserver():
  92. return ps_context().is_role_pserver()
  93. def _is_role_sched():
  94. return ps_context().is_role_sched()