Browse Source

!30491 ut for allgather fusion

Merge pull request !30491 from jiahongQian/master
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
cfe0f76d2b
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 105 additions and 5 deletions
  1. +3
    -1
      docs/api/api_python/mindspore.context.rst
  2. +12
    -4
      mindspore/python/mindspore/context.py
  3. +90
    -0
      tests/ut/python/parallel/test_comm_fusion.py

+ 3
- 1
docs/api/api_python/mindspore.context.rst View File

@@ -220,7 +220,9 @@ MindSpore context,用于配置当前执行环境,包括执行模式、执行

- **comm_fusion** (dict) - 用于设置通信算子的融合配置。可以同一类型的通信算子按梯度张量的大小或者顺序分块传输。输入格式为{"通信类型": {"mode":str, "config": None int 或者 list}},每种通信算子的融合配置有两个键:"mode"和"config"。支持以下通信类型的融合类型和配置:

- allreduce: 进行allreduce算子的通信融合。"mode"包含:"auto"、"size"和"index"。在"auto"模式下,allreduce融合的是梯度变量的大小,默认值阈值为"64"MB,"config"对应的值为None。在"size"模式下,需要用户在config的字典中指定梯度大小阈值,这个值必须大于"0"MB。在"mode"为"index"时,它与"all_reduce_fusion_config"相同,用户需要给"config"传入一个列表,里面每个值表示梯度的索引。
- allreduce: 进行AllReduce算子的通信融合。"mode"包含:"auto"、"size"和"index"。在"auto"模式下,融合的是梯度变量的大小,默认阈值为"64"MB,"config"对应的值为None。在"size"模式下,需要用户在config的字典中指定梯度大小阈值,这个值必须大于"0"MB。在"mode"为"index"时,它与"all_reduce_fusion_config"相同,用户需要给"config"传入一个列表,里面每个值表示梯度的索引。
- allgather: 进行AllGather算子的通信融合。"mode"包含:"auto"、"size"。"auto" 和 "size"模式的配置方式与AllReduce相同。
- reducescatter: 进行ReduceScatter算子的通信融合。"mode"包含:"auto"、"size"。"auto" 和 "size"模式的配置方式与AllReduce相同。

**异常:**



+ 12
- 4
mindspore/python/mindspore/context.py View File

@@ -514,11 +514,19 @@ def set_auto_parallel_context(**kwargs):
It supports following communication fusion types and configurations:

- allreduce: If communication fusion type is `allreduce`. The `mode` contains: `auto`, `size`
and `index`. In `auto` mode, allreduce fusion is configured by gradients size, and the default
fusion threshold is `64` MB. In 'size' mode, allreduce fusion is configured by gradients size
and `index`. In `auto` mode, AllReduce fusion is configured by gradients size and the default
fusion threshold is `64` MB. In 'size' mode, AllReduce fusion is configured by gradients size
manually, and the fusion threshold must be larger than `0` MB. In `index` mode, it is same as
`all_reduce_fusion_config`.

- allgather: If communication fusion type is `allgather`. The `mode` contains: `auto`, `size`.
In `auto` mode, AllGather fusion is configured by gradients size, and the default fusion
threshold is `64` MB. In 'size' mode, AllGather fusion is configured by gradients size
manually, and the fusion threshold must be larger than `0` MB.

- reducescatter: If communication fusion type is `reducescatter`. The `mode` contains: `auto`
and `size`. Config is same as `allgather`.

Raises:
ValueError: If input key is not attribute in auto parallel context.

@@ -540,8 +548,8 @@ def set_auto_parallel_context(**kwargs):
>>> context.set_auto_parallel_context(pipeline_stages=2)
>>> parallel_config = {"gradient_accumulation_shard": True}
>>> context.set_auto_parallel_context(parallel_optimizer_config=parallel_config, enable_parallel_optimizer=True)
>>> comm_fusion_config = {"allreduce": {"mode": "size", "config": 32}}
>>> context.set_auto_parallel_context(comm_fusion=comm_fusion_config)
>>> config = {"allreduce": {"mode": "size", "config": 32}, "allgather": {"mode": "size", "config": 32}}
>>> context.set_auto_parallel_context(comm_fusion=config)
"""
_set_auto_parallel_context(**kwargs)



+ 90
- 0
tests/ut/python/parallel/test_comm_fusion.py View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
@@ -23,6 +24,8 @@ from mindspore.common.initializer import initializer
from mindspore.train.model import Model
from mindspore.nn.wrap.cell_wrapper import PipelineCell
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from tests.ut.python.parallel.test_adafactor import compile_net
from tests.ut.python.parallel.test_adafactor import Net as Net2


class DatasetLenet():
@@ -146,3 +149,90 @@ def test_fusion_auto():
model.train(2, dataset, dataset_sink_mode=False)
assert auto_parallel_context().allgather_fusion_threshold_mb() == 64
assert auto_parallel_context().reducescatter_fusion_threshold_mb() == 64

def test_fusion_optimizer_parallel():
    """
    Feature: allgather/reducescatter communication fusion in optimizer parallel.
    Description: compile a semi-auto-parallel net with "size"-mode fusion thresholds
        for AllGather and ReduceScatter, then recompile the same net with "auto" mode.
    Expectation: compile success
    """
    size_mode_cfg = {"allgather": {"mode": "size", "config": 16},
                     "reducescatter": {"mode": "size", "config": 8}}
    auto_mode_cfg = {"allgather": {"mode": "auto", "config": None},
                     "reducescatter": {"mode": "auto", "config": None}}

    def _set_parallel_context(fusion_cfg):
        # Same parallel context for both passes; only the fusion config differs.
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0,
                                          enable_parallel_optimizer=True, comm_fusion=fusion_cfg)

    _set_parallel_context(size_mode_cfg)
    weight0 = Tensor(np.ones([64, 16, 2]), dtype=ms.float32)
    weight1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
    weight2 = Tensor(np.ones([32]), dtype=ms.float32)
    in_strategy = ((4, 2), (2, 2))
    bias_strategy = ((4, 2), (2,))
    net = Net2(weight0, weight1, weight2, in_strategy, bias_strategy)
    compile_net(net)

    # Recompile the same net with "auto" fusion to cover both modes.
    _set_parallel_context(auto_mode_cfg)
    compile_net(net)

def test_allgather_fusion_invalid_value_failed():
    """
    Feature: test_allgather_fusion with invalid value
    Description: test_allgather_fusion with invalid value
    Expectation: throw TypeError for wrongly-typed values, KeyError for unknown
        communication types, missing keys, or unknown modes
    """
    with pytest.raises(TypeError):
        # comm_fusion must be a dict, not a list
        comm_fusion_dict = [1, 2]
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)

    with pytest.raises(TypeError):
        # each fusion entry must itself be a dict, not a list
        comm_fusion_dict = {"allgather": [1, 2]}
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)

    with pytest.raises(TypeError):
        # "config" must be numeric in "size" mode, not a string
        comm_fusion_dict = {"allgather": {"mode": "size", "config": "30.12"}}
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)

    with pytest.raises(KeyError):
        # unknown communication type "all"
        comm_fusion_dict = {"all": {"mode": "size", "config": 30}}
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)

    with pytest.raises(KeyError):
        # misspelled key "modes" instead of "mode"
        comm_fusion_dict = {"allgather": {"modes": "size", "config": 30}}
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)

    with pytest.raises(KeyError):
        # unknown mode value "sizes"
        comm_fusion_dict = {"allgather": {"mode": "sizes", "config": 30}}
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)

    with pytest.raises(KeyError):
        # "config" key missing entirely
        comm_fusion_dict = {"allgather": {"mode": "size"}}
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)

def test_reducescatter_fusion_invalid_value_failed():
    """
    Feature: test_reducescatter_fusion with invalid value
    Description: test_reducescatter_fusion with invalid value
    Expectation: throw TypeError for wrongly-typed values, KeyError for missing
        keys or unknown modes
    """

    with pytest.raises(TypeError):
        # fusion entry must be a dict, not a list
        comm_fusion_dict = {"reducescatter": [1, 2]}
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)

    with pytest.raises(TypeError):
        # "config" must be numeric in "size" mode, not a string
        comm_fusion_dict = {"reducescatter": {"mode": "size", "config": "30.12"}}
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)

    with pytest.raises(KeyError):
        # misspelled key "modes" instead of "mode"
        comm_fusion_dict = {"reducescatter": {"modes": "size", "config": 30}}
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)

    with pytest.raises(KeyError):
        # unknown mode value "sizes"
        comm_fusion_dict = {"reducescatter": {"mode": "sizes", "config": 30}}
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)

    with pytest.raises(KeyError):
        # "config" key missing entirely
        comm_fusion_dict = {"reducescatter": {"mode": "size"}}
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)

Loading…
Cancel
Save