|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """comm_helper"""
-
- import os
- from ._hccl_management import load_lib as hccl_load_lib
-
# Availability flags start pessimistic and are flipped to True only once the
# corresponding backend has been successfully imported / loaded below.
_HCCL_AVAILABLE = False
_NCCL_AVAILABLE = False

# The MPI binding used for the NCCL backend is optional; a missing module only
# marks NCCL as unavailable.
try:
    import mindspore._ms_mpi as mpi
except ImportError:
    _NCCL_AVAILABLE = False
else:
    _NCCL_AVAILABLE = True

# HCCL counts as available when hccl_load_lib() succeeds; it signals failure by
# raising RuntimeError.
try:
    hccl_load_lib()
except RuntimeError:
    _HCCL_AVAILABLE = False
else:
    _HCCL_AVAILABLE = True

if not _HCCL_AVAILABLE:
    # Fall back to the ``hccl_test.manage.api`` module when the HCCL library
    # could not be loaded; HCCL stays unavailable if that import fails too.
    try:
        import hccl_test.manage.api as hccl
    except ImportError:
        _HCCL_AVAILABLE = False
    else:
        _HCCL_AVAILABLE = True
else:
    from . import _hccl_management as hccl


# Names of the default (world) communication group for each backend.
HCCL_WORLD_COMM_GROUP = "hccl_world_group"
NCCL_WORLD_COMM_GROUP = "nccl_world_group"
# Role of this process, read once at import time (None when MS_ROLE is unset).
MS_ROLE = os.getenv("MS_ROLE")
-
class Backend:
    """
    Namespace of supported communication backends.

    Calling ``Backend(name)`` does not build a ``Backend`` object: it validates
    ``name`` case-insensitively and returns the matching backend string constant.

    Note:
        Backend values are plain strings, e.g. "hccl".
        ``Backend.UNDEFINED`` is never returned; a name resolving to it is rejected
        as invalid.

    Args:
        name (str): The name of the backend, in any letter case.

    Raises:
        TypeError: If name is not a string.
        ValueError: If name does not match a supported backend.

    Examples:
        >>> Backend("hccl")
        'hccl'
        >>> Backend("NCCL")
        'nccl'
    """
    UNDEFINED = "undefined"
    HCCL = "hccl"
    NCCL = "nccl"

    def __new__(cls, name):
        """Validate ``name`` and return the matching backend string constant."""
        if not isinstance(name, str):
            raise TypeError("Backend name must be a string, but got {}".format(type(name)))
        # Backend constants are stored as upper-case class attributes; anything
        # that does not resolve to one falls back to UNDEFINED.
        resolved = getattr(Backend, name.upper(), Backend.UNDEFINED)
        if resolved == Backend.UNDEFINED:
            raise ValueError("Invalid backend: '{}'".format(name))
        return resolved
-
-
def is_hccl_available():
    """
    Report whether the HCCL communication API can be used.

    The underlying flag is set once, when this module is imported.

    Returns:
        bool: True if HCCL is available, False otherwise.
    """
    return _HCCL_AVAILABLE
-
-
def is_nccl_available():
    """
    Report whether the NCCL (MPI binding) communication API can be used.

    The underlying flag is set once, when this module is imported.

    Returns:
        bool: True if NCCL is available, False otherwise.
    """
    return _NCCL_AVAILABLE
-
-
def check_parameter_available(func):
    """
    Decorator that validates the ``group`` and ``backend`` keyword arguments of a helper.

    When the process role (``MS_ROLE``) is "MS_PSERVER" or "MS_SCHED", ``func`` is
    called directly without any checks. Otherwise:

    * ``group``, if passed as a keyword, must be a str or None;
    * ``backend``, if passed as a keyword, must name a backend whose library is available;
    * a ``group`` keyword explicitly set to None is replaced by the world group of
      the selected backend before ``func`` is called.

    Args:
        func (Function): The function to be run.

    Raises:
        TypeError: If ``group`` is neither a str nor None.
        RuntimeError: If the requested backend is not built in.

    Returns:
        Function. The wrapped function.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
            return func(*args, **kwargs)

        group = kwargs.get("group")
        if group is not None and not isinstance(group, str):
            raise TypeError("Group should be str or None, "
                            "but got group {}".format(type(group)))

        # Defaulting to None keeps calls without a ``backend`` keyword from
        # raising UnboundLocalError in the group-defaulting step below.
        backend = kwargs.get("backend")
        # Compare by value: backend names may be equal strings that are not the
        # same object as the Backend constants.
        if backend == Backend.HCCL and not is_hccl_available():
            raise RuntimeError("Distributed Communication doesn't have HCCL built in")
        if backend == Backend.NCCL and not is_nccl_available():
            raise RuntimeError("Distributed Communication doesn't have NCCL built in")

        # Forward the backend's world group in place of an explicit ``group=None``.
        # Only done for the keyword form, so a positionally-passed group is never
        # supplied twice.
        if "group" in kwargs and group is None:
            if backend == Backend.HCCL:
                kwargs["group"] = HCCL_WORLD_COMM_GROUP
            elif backend == Backend.NCCL:
                kwargs["group"] = NCCL_WORLD_COMM_GROUP
        return func(*args, **kwargs)
    return wrapper
-
-
@check_parameter_available
def _get_rank_helper(group, backend):
    """
    Return the rank id of the calling process in ``group``.

    Args:
        group (str): The communication group.
        backend (str): The backend, like "hccl".

    Raises:
        ValueError: If backend is invalid.

    Returns:
        Integer. The rank id of the calling process in the group; always 0 for
        parameter-server ("MS_PSERVER") and scheduler ("MS_SCHED") processes.
    """
    if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
        # These roles never query the communication library.
        return 0
    if backend == Backend.HCCL:
        # The world group is queried without passing an explicit group name.
        if group == HCCL_WORLD_COMM_GROUP:
            return hccl.get_rank_id()
        return hccl.get_rank_id(group)
    if backend == Backend.NCCL:
        return mpi.get_rank_id(group)
    raise ValueError("Invalid backend: '{}'".format(backend))
-
-
@check_parameter_available
def _get_local_rank_helper(group, backend):
    """
    Return the local rank id of the calling process in ``group``.

    Only the HCCL backend supports this query.

    Args:
        group (str): The communication group.
        backend (str): The backend, like "hccl".

    Raises:
        RuntimeError: If backend is NCCL.
        ValueError: If backend is invalid.

    Returns:
        Integer. The local rank id of the calling process.
    """
    if backend == Backend.HCCL:
        # The world group is queried without passing an explicit group name.
        if group == HCCL_WORLD_COMM_GROUP:
            return hccl.get_local_rank_id()
        return hccl.get_local_rank_id(group)
    if backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support get_local_rank_id now.")
    raise ValueError("Invalid backend: '{}'".format(backend))
-
-
@check_parameter_available
def _get_size_helper(group, backend):
    """
    Return the number of ranks in ``group``.

    Args:
        group (str): The communication group.
        backend (str): The backend, like "hccl".

    Raises:
        ValueError: If backend is invalid.

    Returns:
        Integer. The rank size of the specified group; always 1 for
        parameter-server ("MS_PSERVER") and scheduler ("MS_SCHED") processes.
    """
    if MS_ROLE in ("MS_PSERVER", "MS_SCHED"):
        # These roles never query the communication library.
        return 1
    if backend == Backend.HCCL:
        # The world group is queried without passing an explicit group name.
        if group == HCCL_WORLD_COMM_GROUP:
            return hccl.get_rank_size()
        return hccl.get_rank_size(group)
    if backend == Backend.NCCL:
        return mpi.get_rank_size(group)
    raise ValueError("Invalid backend: '{}'".format(backend))
-
-
@check_parameter_available
def _get_local_size_helper(group, backend):
    """
    Return the local rank size of ``group`` as seen by the calling process.

    Only the HCCL backend supports this query.

    Args:
        group (str): The communication group.
        backend (str): The backend, like "hccl".

    Raises:
        RuntimeError: If backend is NCCL.
        ValueError: If backend is invalid.

    Returns:
        Integer. The local rank size of the group the calling process belongs to.
    """
    if backend == Backend.HCCL:
        # The world group is queried without passing an explicit group name.
        if group == HCCL_WORLD_COMM_GROUP:
            return hccl.get_local_rank_size()
        return hccl.get_local_rank_size(group)
    if backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support get_local_rank_size now.")
    raise ValueError("Invalid backend: '{}'".format(backend))
-
-
@check_parameter_available
def _get_world_rank_from_group_rank_helper(group, group_rank_id, backend):
    """
    Translate a rank id in a user group into the corresponding world rank id.

    Only the HCCL backend supports this query, and the world group itself is
    rejected.

    Args:
        group (str): The user communication group.
        group_rank_id (int): A rank id in the user communication group.
        backend (str): The backend, like "hccl".

    Raises:
        TypeError: If group_rank_id is not an int.
        ValueError: If group is "hccl_world_group" or backend is invalid.
        RuntimeError: If backend is NCCL.

    Returns:
        Integer. The rank id in the world communication group.
    """
    # The argument type is validated before the backend is inspected.
    if not isinstance(group_rank_id, int):
        raise TypeError("group_rank_id should be int, but got type {}".format(type(group_rank_id)))
    if backend == Backend.HCCL:
        if group == HCCL_WORLD_COMM_GROUP:
            raise ValueError("Group cannot be 'hccl_world_group'. ")
        return hccl.get_world_rank_from_group_rank(group, group_rank_id)
    if backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support get_world_rank_from_group_rank now.")
    raise ValueError("Invalid backend: '{}'".format(backend))
-
-
@check_parameter_available
def _get_group_rank_from_world_rank_helper(world_rank_id, group, backend):
    """
    Translate a world rank id into the corresponding rank id in a user group.

    Only the HCCL backend supports this query, and the world group itself is
    rejected.

    Args:
        world_rank_id (int): A rank id in the world communication group.
        group (str): The user communication group.
        backend (str): The backend, like "hccl".

    Raises:
        TypeError: If world_rank_id is not an int.
        ValueError: If group is 'hccl_world_group' or backend is invalid.
        RuntimeError: If backend is NCCL.

    Returns:
        Integer. The rank id in the user communication group.
    """
    # The argument type is validated before the backend is inspected.
    if not isinstance(world_rank_id, int):
        raise TypeError("world_rank_id should be int, but got type {}".format(type(world_rank_id)))
    if backend == Backend.HCCL:
        if group == HCCL_WORLD_COMM_GROUP:
            raise ValueError("Group cannot be 'hccl_world_group'. ")
        return hccl.get_group_rank_from_world_rank(world_rank_id, group)
    if backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support get_group_rank_from_world_rank now.")
    raise ValueError("Invalid backend: '{}'".format(backend))
-
-
@check_parameter_available
def _create_group_helper(group, rank_ids, backend):
    """
    Create a communication group made of the given ranks.

    Only the HCCL backend supports group creation.

    Args:
        group (str): The communication group.
        rank_ids (list): Rank ids in the group; must be non-empty and free of duplicates.
        backend (str): The backend, like "hccl".

    Raises:
        TypeError: If rank_ids is not a list.
        ValueError: If rank_ids is empty, rank_ids has duplicate data, or backend is invalid.
        RuntimeError: If backend is NCCL.
    """
    if backend == Backend.HCCL:
        if not isinstance(rank_ids, list):
            raise TypeError("Rank_ids {} should be list".format(rank_ids))
        rank_size = len(rank_ids)
        if rank_size < 1:
            raise ValueError("Rank_ids size {} should be large than 0".format(rank_size))
        # A set drops repeated entries, so a smaller set means a rank id occurs twice.
        if len(set(rank_ids)) != rank_size:
            raise ValueError("List rank_ids in Group {} has duplicate data!".format(group))
        hccl.create_group(group, rank_size, rank_ids)
    elif backend == Backend.NCCL:
        raise RuntimeError("Nccl doesn't support create_group now.")
    else:
        raise ValueError("Invalid backend: '{}'".format(backend))
-
-
@check_parameter_available
def _destroy_group_helper(group, backend):
    """
    Destroy a user communication group.

    Only the HCCL backend supports this, and the world group cannot be destroyed.

    Args:
        group (str): The user communication group.
        backend (str): The backend, like "hccl".

    Raises:
        ValueError: If group is "hccl_world_group" or backend is invalid.
        RuntimeError: If backend is NCCL.
    """
    if backend == Backend.NCCL and backend != Backend.HCCL:
        raise RuntimeError("Nccl doesn't support destroy_group now.")
    if backend != Backend.HCCL:
        raise ValueError("Invalid backend: '{}'".format(backend))
    if group == HCCL_WORLD_COMM_GROUP:
        raise ValueError("The hccl_world_group does not support destruction.")
    hccl.destroy_group(group)
|