You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

_mpi_config.py 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. The MPI config, used to configure the MPI environment.
  17. """
  18. import threading
  19. from mindspore._c_expression import MpiConfig
  20. from mindspore._checkparam import args_type_check
  21. class _MpiConfig:
  22. """
  23. _MpiConfig is the config tool for controlling MPI
  24. Note:
  25. Create a config through instantiating MpiConfig object is not recommended.
  26. should use MpiConfig() to get the config since MpiConfig is singleton.
  27. """
  28. _instance = None
  29. _instance_lock = threading.Lock()
  30. def __init__(self):
  31. self._mpiconfig_handle = MpiConfig.get_instance()
  32. def __new__(cls, *args, **kwargs):
  33. if cls._instance is None:
  34. cls._instance_lock.acquire()
  35. cls._instance = object.__new__(cls)
  36. cls._instance_lock.release()
  37. return cls._instance
  38. def __getattribute__(self, attr):
  39. value = object.__getattribute__(self, attr)
  40. if attr == "_mpiconfig_handle" and value is None:
  41. raise ValueError("mpiconfig handle is none in MpiConfig!!!")
  42. return value
  43. @property
  44. def enable_mpi(self):
  45. return self._mpiconfig_handle.get_enable_mpi()
  46. @enable_mpi.setter
  47. def enable_mpi(self, enable_mpi):
  48. self._mpiconfig_handle.set_enable_mpi(enable_mpi)
  49. _k_mpi_config = None
  50. def _mpi_config():
  51. """
  52. Get the global mpi config, if mpi config is not created, create a new one.
  53. Returns:
  54. _MpiConfig, the global mpi config.
  55. """
  56. global _k_mpi_config
  57. if _k_mpi_config is None:
  58. _k_mpi_config = _MpiConfig()
  59. return _k_mpi_config
  60. @args_type_check(enable_mpi=bool)
  61. def _set_mpi_config(**kwargs):
  62. """
  63. Sets mpi config for running environment.
  64. mpi config should be configured before running your program. If there is no configuration,
  65. mpi moudle will be disabled by default.
  66. Note:
  67. Attribute name is required for setting attributes.
  68. Args:
  69. enable_mpi (bool): Whether to enable mpi. Default: False.
  70. Raises:
  71. ValueError: If input key is not an attribute in mpi config.
  72. Examples:
  73. >>> mpiconfig.set_mpi_config(enable_mpi=True)
  74. """
  75. for key, value in kwargs.items():
  76. if not hasattr(_mpi_config(), key):
  77. raise ValueError("Set mpi config keyword %s is not recognized!" % key)
  78. setattr(_mpi_config(), key, value)
  79. def _get_mpi_config(attr_key):
  80. """
  81. Gets mpi config attribute value according to the input key.
  82. Args:
  83. attr_key (str): The key of the attribute.
  84. Returns:
  85. Object, The value of given attribute key.
  86. Raises:
  87. ValueError: If input key is not an attribute in config.
  88. """
  89. if not hasattr(_mpi_config(), attr_key):
  90. raise ValueError("Get context keyword %s is not recognized!" % attr_key)
  91. return getattr(_mpi_config(), attr_key)