
!5782 change allreduce fusion function

Merge pull request !5782 from wangmin0104/master
tags/v1.0.0
mindspore-ci-bot committed 5 years ago
commit 4e7346509f
7 changed files with 148 additions and 247 deletions
  1. model_zoo/official/cv/resnet_thor/README.md (+1, -1)
  2. model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py (+64, -109)
  3. model_zoo/official/cv/resnet_thor/src/thor.py (+12, -9)
  4. model_zoo/official/cv/resnet_thor/train.py (+1, -5)
  5. tests/st/networks/models/resnet50/src_thor/grad_reducer_thor.py (+64, -114)
  6. tests/st/networks/models/resnet50/src_thor/thor.py (+5, -4)
  7. tests/st/networks/models/resnet50/test_resnet50_imagenet.py (+1, -5)

model_zoo/official/cv/resnet_thor/README.md (+1, -1)

@@ -217,7 +217,7 @@ Inference result will be stored in the example path, whose folder name is "eval"
 ```
 Inference result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log.
 ```
-result: {'top_5_accuracy': 0.9286771766965429, 'top_1_accuracy': 0.7613036171574904} ckpt=train_parallel/resnet-36_5004.ckpt
+result: {'top_5_accuracy': 0.9287972151088348, 'top_1_accuracy': 0.7597031049935979} ckpt=train_parallel/resnet-36_5004.ckpt
 ```

 ## Model Description


model_zoo/official/cv/resnet_thor/src/grad_reducer_thor.py (+64, -109)

@@ -12,149 +12,109 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""grad_reducer_thor"""
-import mindspore.common.dtype as mstype
-from mindspore.communication.management import GlobalComm, get_group_size
+"""grad reducer cell for distributed training"""
 from mindspore.nn.cell import Cell
+from mindspore.communication.management import GlobalComm, get_group_size
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp
+from mindspore.ops.operations.comm_ops import AllReduce
+import mindspore.common.dtype as mstype


 reduce_opt = C.MultitypeFuncGraph("reduce_opt")


-_all_reduce_A = AllReduce()
-
-
-def _init_optimizer_allreduce(group):
-    global _all_reduce_A
-    _all_reduce_A = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
-    _all_reduce_A.add_prim_attr('fusion', group)
-
-
-@reduce_opt.register("Function", "Number", "Tensor")
-def _tensors_allreduce_mean(mul, degree, grad):
-    degree = F.scalar_cast(degree, F.dtype(grad))
-    grad = _all_reduce_A(grad)
+def _init_allreduce_operators(length, split_indices):
+    """ initialize allreduce communication operators"""
+    indices = split_indices[0]
+    fusion = split_indices[1]
+    op_list = ()
+    j = 0
+    for i in range(length):
+        if j <= len(indices)-1:
+            temp = indices[j]
+        else:
+            temp = length
+        if i >= temp:
+            j = j + 1
+            fusion = fusion + 1
+        op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
+        op.add_prim_attr('fusion', fusion)
+        op_list = op_list + (op,)
+    return op_list
+
+
+@reduce_opt.register("Function", "Number", "Function", "Tensor")
+def _tensors_allreduce_mean(mul, degree, allreduce, parameters):
+    """
+    Apply allreduce on parameters.
+
+    Args:
+        mul(Primitive): The mul operator for parameters.
+        degree (int): The mean coefficient.
+        allreduce (Primitive): The communication operator for parameters.
+        parameters (Tensor): The parameters before operation.
+
+    Returns:
+        Tensor, the parameters after operation.
+    """
+    degree = F.scalar_cast(degree, F.dtype(parameters))
+    parameters = allreduce(parameters)
     cast_op = P.Cast()
-    return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
-
-
-@reduce_opt.register("Bool", "Tensor")
-def _tensors_allreduce(allreduce_filter, grad):
-    if allreduce_filter:
-        return _all_reduce_A(grad)
-    return grad
+    return mul(parameters, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(parameters)))


 _get_datatype = C.MultitypeFuncGraph("_get_datatype")


 @_get_datatype.register("Tensor")
-def _tensors_get_datatype(grad):
+def _tensors_get_datatype(parameters):
     """
-    Acquire gradient datatype.
+    Acquire parameters datatype.

     Args:
-        grad (Tensor): The gradient tensor before operation.
+        parameters (Tensor): The parameters before operation.

     Returns:
-        mstype, the datatype of gradient.
+        mstype, the datatype of parameters.
     """
-    return F.dtype(grad)
+    return F.dtype(parameters)


 _cast_datatype = C.MultitypeFuncGraph("_cast_datatype")


 @_cast_datatype.register("TypeType", "Tensor")
-def _tensors_cast_datatype(datatype, grad):
+def _tensors_cast_datatype(datatype, parameters):
     """
-    Cast gradient to datatype.
+    Cast parameters to datatype.

     Args:
-        datatype (mstype): the destination datatype of gradient.
-        grad (Tensor): The gradient tensor before operation.
+        datatype (mstype): the destination datatype of parameters.
+        parameters (Tensor): The parameters before operation.

     Returns:
-        Tensor, the gradient tensor after operation.
+        Tensor, the parameters after operation.
     """
-    return F.cast(grad, datatype)
+    return F.cast(parameters, datatype)


 class DistributedGradReducerThor(Cell):
     """
     A distributed optimizer.

-    Constructs a gradient reducer Cell, which applies communication and average operations on
-    single-process gradient values.
+    Constructs a parameters reducer Cell, which applies communication and average operations on
+    single-process parameters values.

     Args:
-        parameters (list): the parameters to be updated.
-        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: False.
+        parameter_length (int): length of the parameters to be updated.
+        split_indices(tuple): parameter split indices.
+        mean (bool): When mean is true, the mean coefficient (degree) would apply on parameters. Default: False.
         degree (int): The mean coefficient. Usually it equals to device number. Default: None.

     Raises:
         ValueError: If degree is not a int or less than 0.
-
-    Examples:
-        >>> from mindspore.communication import init, get_group_size
-        >>> from mindspore.ops import composite as C
-        >>> from mindspore.ops import operations as P
-        >>> from mindspore.ops import functional as F
-        >>> from mindspore import context
-        >>> from mindspore import nn
-        >>> from mindspore import ParameterTuple
-        >>> from mindspore.context import ParallelMode
-        >>>
-        >>> device_id = int(os.environ["DEVICE_ID"])
-        >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
-        >>>                     device_id=int(device_id), enable_hccl=True)
-        >>> init()
-        >>> context.reset_auto_parallel_context()
-        >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
-        >>>
-        >>>
-        >>> class TrainingWrapper(nn.Cell):
-        >>>     def __init__(self, network, optimizer, sens=1.0):
-        >>>         super(TrainingWrapper, self).__init__(auto_prefix=False)
-        >>>         self.network = network
-        >>>         self.network.add_flags(defer_inline=True)
-        >>>         self.weights = ParameterTuple(network.trainable_params())
-        >>>         self.optimizer = optimizer
-        >>>         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        >>>         self.sens = sens
-        >>>         self.reducer_flag = False
-        >>>         self.grad_reducer = None
-        >>>         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        >>>         if self.parallel_mode in [ParallelMode.DATA_PARALLEL,
-        >>>                                   ParallelMode.HYBRID_PARALLEL]:
-        >>>             self.reducer_flag = True
-        >>>         if self.reducer_flag:
-        >>>             mean = context.get_auto_parallel_context("gradients_mean")
-        >>>             if mean.get_device_num_is_set():
-        >>>                 degree = context.get_auto_parallel_context("device_num")
-        >>>             else:
-        >>>                 degree = get_group_size()
-        >>>             self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
-        >>>
-        >>>     def construct(self, *args):
-        >>>         weights = self.weights
-        >>>         loss = self.network(*args)
-        >>>         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        >>>         grads = self.grad(self.network, weights)(*args, sens)
-        >>>         if self.reducer_flag:
-        >>>             # apply grad reducer on grads
-        >>>             grads = self.grad_reducer(grads)
-        >>>         return F.depend(loss, self.optimizer(grads))
-        >>>
-        >>> network = Net()
-        >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
-        >>> train_cell = TrainingWrapper(network, optimizer)
-        >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
-        >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
-        >>> grads = train_cell(inputs, label)
     """

-    def __init__(self, parameters, group, mean=True, degree=None):
+    def __init__(self, parameter_length, split_indices, mean=True, degree=None):
         super(DistributedGradReducerThor, self).__init__(auto_prefix=False)
         self.hyper_map = C.HyperMap()
         self.mul = P.Mul()
@@ -165,16 +125,11 @@ class DistributedGradReducerThor(Cell):
                 raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int")
             self.degree = degree
         self.mean = mean
-        self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
-        _init_optimizer_allreduce(group)
+        self.op_list = _init_allreduce_operators(parameter_length, split_indices)

-    def construct(self, grads):
-        # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
-        # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
-        # and cast back after the operation.
-        datatypes = self.hyper_map(F.partial(_get_datatype), grads)
-        grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
-
-        new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
-        new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
-        return new_grad
+    def construct(self, parameters):
+        datatypes = self.hyper_map(F.partial(_get_datatype), parameters)
+        parameters = self.hyper_map(F.partial(_cast_datatype, mstype.float32), parameters)
+        new_parameters = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.op_list, parameters)
+        new_parameters = self.hyper_map(F.partial(_cast_datatype), datatypes, new_parameters)
+        return new_parameters
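
The core of this change is `_init_allreduce_operators`, which replaces the single global `_all_reduce_A` (one fusion group per reducer) with one `AllReduce` operator per parameter, each tagged with a `fusion` attribute derived from `split_indices`. The sketch below reproduces just that fusion-id bookkeeping in plain Python so the grouping can be checked without MindSpore; the helper name `fusion_ids` and the example sizes are illustrative, not part of the patch.

```python
# Standalone sketch of the grouping logic in _init_allreduce_operators above:
# it returns the fusion id assigned to each parameter index instead of building
# AllReduce operators. Pure Python, no MindSpore required.
def fusion_ids(length, split_indices):
    indices = split_indices[0]   # boundaries where a new fusion group starts
    fusion = split_indices[1]    # fusion id of the first group
    ids = []
    j = 0
    for i in range(length):
        temp = indices[j] if j <= len(indices) - 1 else length
        if i >= temp:
            j = j + 1
            fusion = fusion + 1
        ids.append(fusion)
    return ids


# Example: 54 parameters split at index 27, starting from fusion id 2
# (the ((27,), 2) shape of the tuples passed by the THOR optimizer below).
ids = fusion_ids(54, ((27,), 2))
print(ids[0], ids[26], ids[27], ids[53])  # 2 2 3 3 -> two fused AllReduce groups
```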

model_zoo/official/cv/resnet_thor/src/thor.py (+12, -9)

@@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype
 from mindspore._checkparam import check_bool
 from mindspore._checkparam import Validator as validator
 from mindspore.nn.optim.optimizer import Optimizer
-from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
+from mindspore.parallel._utils import _get_device_num, _get_mirror_mean
 from src.grad_reducer_thor import DistributedGradReducerThor

 _momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@@ -85,10 +85,12 @@ class THOR_GPU(Optimizer):
         self.assign = P.Assign()
         self.mul = P.Mul()

-        mean = _get_gradients_mean()
+        mean = _get_mirror_mean()
         degree = _get_device_num()
-        self.grad_reducer_thorA = DistributedGradReducerThor(self.parameters, 0, mean, degree)
-        self.grad_reducer_thorG = DistributedGradReducerThor(self.parameters, 0, mean, degree)
+
+        parameter_length = len(self.feature_map)
+        self.grad_reducer_thorA = DistributedGradReducerThor(parameter_length, ((parameter_length,), 0), mean, degree)
+        self.grad_reducer_thorG = DistributedGradReducerThor(parameter_length, ((parameter_length,), 0), mean, degree)
         self.weight_decay = weight_decay
         self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
         self.update_gradient = P.UpdateThorGradient(split_dim=128)
@@ -191,12 +193,13 @@
                             1.0 / 196, 1.0 / 196, 1.0 / 196,
                             1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
                             1.0]
-        mean = _get_gradients_mean()
+        mean = _get_mirror_mean()
         degree = _get_device_num()
-        self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree)
-        self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree)
-        self.grad_reducer_A = DistributedGradReducerThor(self.parameters, 3, mean, degree)
-        self.grad_reducer_G = DistributedGradReducerThor(self.parameters, 4, mean, degree)
+        parameter_length = len(self.feature_map)
+        self.grad_reducer_Amax = DistributedGradReducerThor(parameter_length, ((27,), 2), mean, degree)
+        self.grad_reducer_Gmax = DistributedGradReducerThor(parameter_length, ((27,), 4), mean, degree)
+        self.grad_reducer_A = DistributedGradReducerThor(parameter_length, ((27,), 6), mean, degree)
+        self.grad_reducer_G = DistributedGradReducerThor(parameter_length, ((27,), 8), mean, degree)
         self.matrix_A_inv = ()
         self.matrix_G_inv = ()
         self.matrix_max_inv = ()
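
With the new signature, each THOR reducer owns its operator list instead of sharing the global `AllReduce`, and the `((27,), base)` tuples stagger the starting fusion ids. Below is a short sketch of how I read the resulting id layout; the two-groups-per-reducer count follows from the single split at index 27, while the non-overlap interpretation is mine and not stated in the PR.

```python
# Assumed reading of the fusion-id layout for the four Ascend THOR reducers above:
# each reducer splits its tensors once (at index 27), so it occupies its base
# fusion id and the next one, disjoint from the other reducers.
reducer_bases = {"Amax": 2, "Gmax": 4, "A": 6, "G": 8}
for name, base in reducer_bases.items():
    print("grad_reducer_%s -> fusion ids %d and %d" % (name, base, base + 1))
# grad_reducer_Amax -> fusion ids 2 and 3
# grad_reducer_Gmax -> fusion ids 4 and 5
# grad_reducer_A -> fusion ids 6 and 7
# grad_reducer_G -> fusion ids 8 and 9
```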


model_zoo/official/cv/resnet_thor/train.py (+1, -5)

@@ -95,11 +95,7 @@ if __name__ == '__main__':
         context.set_context(device_id=device_id, enable_auto_mixed_precision=True)
         context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True)
-        auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5")
+        auto_parallel_context().set_all_reduce_fusion_split_indices([107])
         init()
     # GPU target
     else:
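
Because each reducer now carries its own `fusion` attributes, train.py no longer needs a separate split-index list per `hccl_world_groupsumN` communication group; only the split for the ordinary gradient allreduce remains. A minimal sketch of the resulting Ascend setup, condensed from the new lines above; `device_num=8` is a placeholder and the import locations reflect my reading of the MindSpore version targeted here.

```python
# Condensed from the updated train.py above (hedged sketch, not the full script).
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init
from mindspore.parallel._auto_parallel_context import auto_parallel_context

context.set_auto_parallel_context(device_num=8, parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True)
# One split list for the gradient allreduce; the THOR reducers manage their own fusion groups.
auto_parallel_context().set_all_reduce_fusion_split_indices([107])
init()
```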


tests/st/networks/models/resnet50/src_thor/grad_reducer_thor.py (+64, -114)

@@ -12,150 +12,109 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""grad_reducer_thor"""
-import mindspore.common.dtype as mstype
-from mindspore.communication.management import GlobalComm, get_group_size
+"""grad reducer cell for distributed training"""
 from mindspore.nn.cell import Cell
+from mindspore.communication.management import GlobalComm, get_group_size
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp
+from mindspore.ops.operations.comm_ops import AllReduce
+import mindspore.common.dtype as mstype


 reduce_opt = C.MultitypeFuncGraph("reduce_opt")


-_all_reduce_A = AllReduce()
-
-
-def _init_optimizer_allreduce(group):
-    global _all_reduce_A
-    _all_reduce_A = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP)
-    _all_reduce_A.add_prim_attr('fusion', group)
-
-
-@reduce_opt.register("Function", "Number", "Tensor")
-def _tensors_allreduce_mean(mul, degree, grad):
-    degree = F.scalar_cast(degree, F.dtype(grad))
-    grad = _all_reduce_A(grad)
+def _init_allreduce_operators(length, split_indices):
+    """ initialize allreduce communication operators"""
+    indices = split_indices[0]
+    fusion = split_indices[1]
+    op_list = ()
+    j = 0
+    for i in range(length):
+        if j <= len(indices)-1:
+            temp = indices[j]
+        else:
+            temp = length
+        if i >= temp:
+            j = j + 1
+            fusion = fusion + 1
+        op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
+        op.add_prim_attr('fusion', fusion)
+        op_list = op_list + (op,)
+    return op_list
+
+
+@reduce_opt.register("Function", "Number", "Function", "Tensor")
+def _tensors_allreduce_mean(mul, degree, allreduce, parameters):
+    """
+    Apply allreduce on parameters.
+
+    Args:
+        mul(Primitive): The mul operator for parameters.
+        degree (int): The mean coefficient.
+        allreduce (Primitive): The communication operator for parameters.
+        parameters (Tensor): The parameters before operation.
+
+    Returns:
+        Tensor, the parameters after operation.
+    """
+    degree = F.scalar_cast(degree, F.dtype(parameters))
+    parameters = allreduce(parameters)
     cast_op = P.Cast()
-    return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad)))
-
-
-@reduce_opt.register("Bool", "Tensor")
-def _tensors_allreduce(allreduce_filter, grad):
-    if allreduce_filter:
-        return _all_reduce_A(grad)
-    return grad
+    return mul(parameters, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(parameters)))


 _get_datatype = C.MultitypeFuncGraph("_get_datatype")


 @_get_datatype.register("Tensor")
-def _tensors_get_datatype(grad):
+def _tensors_get_datatype(parameters):
     """
-    Acquire gradient datatype.
+    Acquire parameters datatype.

     Args:
-        grad (Tensor): The gradient tensor before operation.
+        parameters (Tensor): The parameters before operation.

     Returns:
-        mstype, the datatype of gradient.
+        mstype, the datatype of parameters.
     """
-    return F.dtype(grad)
+    return F.dtype(parameters)


 _cast_datatype = C.MultitypeFuncGraph("_cast_datatype")


 @_cast_datatype.register("TypeType", "Tensor")
-def _tensors_cast_datatype(datatype, grad):
+def _tensors_cast_datatype(datatype, parameters):
     """
-    Cast gradient to datatype.
+    Cast parameters to datatype.

     Args:
-        datatype (mstype): the destination datatype of gradient.
-        grad (Tensor): The gradient tensor before operation.
+        datatype (mstype): the destination datatype of parameters.
+        parameters (Tensor): The parameters before operation.

     Returns:
-        Tensor, the gradient tensor after operation.
+        Tensor, the parameters after operation.
     """
-    return F.cast(grad, datatype)
+    return F.cast(parameters, datatype)


 class DistributedGradReducerThor(Cell):
     """
     A distributed optimizer.

-    Constructs a gradient reducer Cell, which applies communication and average operations on
-    single-process gradient values.
+    Constructs a parameters reducer Cell, which applies communication and average operations on
+    single-process parameters values.

     Args:
-        parameters (list): the parameters to be updated.
-        group (int): the different group to allreduce.
-        mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: False.
+        parameter_length (int): length of the parameters to be updated.
+        split_indices(tuple): parameter split indices.
+        mean (bool): When mean is true, the mean coefficient (degree) would apply on parameters. Default: False.
         degree (int): The mean coefficient. Usually it equals to device number. Default: None.

     Raises:
         ValueError: If degree is not a int or less than 0.
-
-    Examples:
-        >>> from mindspore.communication import init, get_group_size
-        >>> from mindspore.ops import composite as C
-        >>> from mindspore.ops import operations as P
-        >>> from mindspore.ops import functional as F
-        >>> from mindspore import context
-        >>> from mindspore import nn
-        >>> from mindspore import ParameterTuple
-        >>> from mindspore.context import ParallelMode
-        >>>
-        >>> device_id = int(os.environ["DEVICE_ID"])
-        >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True,
-        >>>                     device_id=int(device_id), enable_hccl=True)
-        >>> init()
-        >>> context.reset_auto_parallel_context()
-        >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
-        >>>
-        >>>
-        >>> class TrainingWrapper(nn.Cell):
-        >>>     def __init__(self, network, optimizer, sens=1.0):
-        >>>         super(TrainingWrapper, self).__init__(auto_prefix=False)
-        >>>         self.network = network
-        >>>         self.network.add_flags(defer_inline=True)
-        >>>         self.weights = ParameterTuple(network.trainable_params())
-        >>>         self.optimizer = optimizer
-        >>>         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        >>>         self.sens = sens
-        >>>         self.reducer_flag = False
-        >>>         self.grad_reducer = None
-        >>>         self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        >>>         if self.parallel_mode in [ParallelMode.DATA_PARALLEL,
-        >>>                                   ParallelMode.HYBRID_PARALLEL]:
-        >>>             self.reducer_flag = True
-        >>>         if self.reducer_flag:
-        >>>             mean = context.get_auto_parallel_context("gradients_mean")
-        >>>             if mean.get_device_num_is_set():
-        >>>                 degree = context.get_auto_parallel_context("device_num")
-        >>>             else:
-        >>>                 degree = get_group_size()
-        >>>             self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
-        >>>
-        >>>     def construct(self, *args):
-        >>>         weights = self.weights
-        >>>         loss = self.network(*args)
-        >>>         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        >>>         grads = self.grad(self.network, weights)(*args, sens)
-        >>>         if self.reducer_flag:
-        >>>             # apply grad reducer on grads
-        >>>             grads = self.grad_reducer(grads)
-        >>>         return F.depend(loss, self.optimizer(grads))
-        >>>
-        >>> network = Net()
-        >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9)
-        >>> train_cell = TrainingWrapper(network, optimizer)
-        >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32))
-        >>> label = Tensor(np.zeros([16, 16]).astype(np.float32))
-        >>> grads = train_cell(inputs, label)
     """

-    def __init__(self, parameters, group, mean=True, degree=None):
+    def __init__(self, parameter_length, split_indices, mean=True, degree=None):
         super(DistributedGradReducerThor, self).__init__(auto_prefix=False)
         self.hyper_map = C.HyperMap()
         self.mul = P.Mul()
@@ -166,20 +125,11 @@ class DistributedGradReducerThor(Cell):
                 raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int")
             self.degree = degree
         self.mean = mean
-        self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters)
-        _init_optimizer_allreduce(group)
+        self.op_list = _init_allreduce_operators(parameter_length, split_indices)

-    def construct(self, grads):
-        # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the
-        # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce,
-        # and cast back after the operation.
-        datatypes = self.hyper_map(F.partial(_get_datatype), grads)
-        grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads)
-
-        if self.mean:
-            new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads)
-        else:
-            new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads)
-
-        new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad)
-        return new_grad
+    def construct(self, parameters):
+        datatypes = self.hyper_map(F.partial(_get_datatype), parameters)
+        parameters = self.hyper_map(F.partial(_cast_datatype, mstype.float32), parameters)
+        new_parameters = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.op_list, parameters)
+        new_parameters = self.hyper_map(F.partial(_cast_datatype), datatypes, new_parameters)
+        return new_parameters

tests/st/networks/models/resnet50/src_thor/thor.py (+5, -4)

@@ -89,10 +89,11 @@ class THOR(Optimizer):
                             1.0]
         mean = _get_gradients_mean()
         degree = _get_device_num()
-        self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree)
-        self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree)
-        self.grad_reducer_A = DistributedGradReducerThor(self.parameters, 3, mean, degree)
-        self.grad_reducer_G = DistributedGradReducerThor(self.parameters, 4, mean, degree)
+        parameter_length = len(self.feature_map)
+        self.grad_reducer_Amax = DistributedGradReducerThor(parameter_length, ((27,), 2), mean, degree)
+        self.grad_reducer_Gmax = DistributedGradReducerThor(parameter_length, ((27,), 4), mean, degree)
+        self.grad_reducer_A = DistributedGradReducerThor(parameter_length, ((27,), 6), mean, degree)
+        self.grad_reducer_G = DistributedGradReducerThor(parameter_length, ((27,), 8), mean, degree)
         self.matrix_A_inv = ()
         self.matrix_G_inv = ()
         self.matrix_max_inv = ()


tests/st/networks/models/resnet50/test_resnet50_imagenet.py (+1, -5)

@@ -241,11 +241,7 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
     if enable_hccl:
         context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True, parameter_broadcast=True)
-        auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
-        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5")
+        auto_parallel_context().set_all_reduce_fusion_split_indices([107])
         init()

     # network

