Browse Source

Integrate two allreduce fusion set interfaces into one

tags/v0.3.0-alpha
yao_yf chang zherui 5 years ago
parent
commit
f57bd919e0
3 changed files with 13 additions and 9 deletions
  1. +1
    -3
      mindspore/parallel/__init__.py
  2. +10
    -4
      mindspore/parallel/_auto_parallel_context.py
  3. +2
    -2
      mindspore/parallel/_dp_allreduce_fusion.py

+ 1
- 3
mindspore/parallel/__init__.py View File

@@ -15,9 +15,7 @@
""" """
This interface is ONLY used in Auto-parallel procedure. This interface is ONLY used in Auto-parallel procedure.
""" """
from .dp_allreduce_fusion import set_fusion_strategy_by_idx, set_fusion_strategy_by_size
from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \ from .algo_parameter_config import get_algo_parameters, reset_algo_parameters, \
set_algo_parameters set_algo_parameters


__all__ = ["set_fusion_strategy_by_idx", "set_fusion_strategy_by_size", "get_algo_parameters",
"reset_algo_parameters", "set_algo_parameters"]
__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]

+ 10
- 4
mindspore/parallel/_auto_parallel_context.py View File

@@ -14,6 +14,8 @@
# ============================================================================ # ============================================================================
"""Context of auto parallel""" """Context of auto parallel"""
import threading import threading
import mindspore.context as context
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
from mindspore._c_expression import AutoParallelContext from mindspore._c_expression import AutoParallelContext
from mindspore._extends.pynative_helper import args_type_check from mindspore._extends.pynative_helper import args_type_check


@@ -219,13 +221,15 @@ class _AutoParallelContext:
indices (list): Indices list. indices (list): Indices list.


Raises: Raises:
ValueError: If type of indices item is not int.
TypeError: If type of indices item is not int.
""" """
self.check_context_handle() self.check_context_handle()
for index in indices: for index in indices:
if not isinstance(index, int): if not isinstance(index, int):
raise TypeError('indices has invalid value') raise TypeError('indices has invalid value')
return self._context_handle.set_all_reduce_fusion_split_indices(indices)
self._context_handle.set_all_reduce_fusion_split_indices(indices)
if context.get_context("device_target") == "Ascend":
_set_fusion_strategy_by_idx(indices)


def get_all_reduce_fusion_split_indices(self): def get_all_reduce_fusion_split_indices(self):
"""Get allreduce fusion split indices.""" """Get allreduce fusion split indices."""
@@ -240,13 +244,15 @@ class _AutoParallelContext:
sizes (list): Sizes list. sizes (list): Sizes list.


Raises: Raises:
ValueError: If type of sizes item is not int.
TypeError: If type of sizes item is not int.
""" """
self.check_context_handle() self.check_context_handle()
for size in sizes: for size in sizes:
if not isinstance(size, int): if not isinstance(size, int):
raise TypeError('sizes has invalid value') raise TypeError('sizes has invalid value')
return self._context_handle.set_all_reduce_fusion_split_sizes(sizes)
self._context_handle.set_all_reduce_fusion_split_sizes(sizes)
if context.get_context("device_target") == "Ascend":
_set_fusion_strategy_by_size(sizes)


def get_all_reduce_fusion_split_sizes(self): def get_all_reduce_fusion_split_sizes(self):
"""Get allreduce fusion split sizes.""" """Get allreduce fusion split sizes."""


mindspore/parallel/dp_allreduce_fusion.py → mindspore/parallel/_dp_allreduce_fusion.py View File

@@ -43,7 +43,7 @@ def _c_array(ctype, values):
return (ctype * len(values))(*values) return (ctype * len(values))(*values)
def set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
def _set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
""" """
A function set gradient segment strategy according to the index list. A function set gradient segment strategy according to the index list.
@@ -100,7 +100,7 @@ def set_fusion_strategy_by_idx(idxList, group="hccl_world_group"):
raise RuntimeError('Allreduce split error') raise RuntimeError('Allreduce split error')
def set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
def _set_fusion_strategy_by_size(dataSizeList, group="hccl_world_group"):
""" """
A function set gradient segment strategy according to the data size percentage list. A function set gradient segment strategy according to the data size percentage list.

Loading…
Cancel
Save