Merge pull request !5782 from wangmin0104/mastertags/v1.0.0
| @@ -217,7 +217,7 @@ Inference result will be stored in the example path, whose folder name is "eval" | |||||
| ``` | ``` | ||||
| Inference result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log. | Inference result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log. | ||||
| ``` | ``` | ||||
| result: {'top_5_accuracy': 0.9286771766965429, 'top_1_accuracy': 0.7613036171574904} ckpt=train_parallel/resnet-36_5004.ckpt | |||||
| result: {'top_5_accuracy': 0.9287972151088348, 'top_1_accuracy': 0.7597031049935979} ckpt=train_parallel/resnet-36_5004.ckpt | |||||
| ``` | ``` | ||||
| ## Model Description | ## Model Description | ||||
| @@ -12,149 +12,109 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """grad_reducer_thor""" | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.communication.management import GlobalComm, get_group_size | |||||
| """grad reducer cell for distributed training""" | |||||
| from mindspore.nn.cell import Cell | from mindspore.nn.cell import Cell | ||||
| from mindspore.communication.management import GlobalComm, get_group_size | |||||
| from mindspore.ops import functional as F, composite as C, operations as P | from mindspore.ops import functional as F, composite as C, operations as P | ||||
| from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp | |||||
| from mindspore.ops.operations.comm_ops import AllReduce | |||||
| import mindspore.common.dtype as mstype | |||||
| reduce_opt = C.MultitypeFuncGraph("reduce_opt") | reduce_opt = C.MultitypeFuncGraph("reduce_opt") | ||||
| _all_reduce_A = AllReduce() | |||||
| def _init_allreduce_operators(length, split_indices): | |||||
| """ initialize allreduce communication operators""" | |||||
| indices = split_indices[0] | |||||
| fusion = split_indices[1] | |||||
| op_list = () | |||||
| j = 0 | |||||
| for i in range(length): | |||||
| if j <= len(indices)-1: | |||||
| temp = indices[j] | |||||
| else: | |||||
| temp = length | |||||
| if i >= temp: | |||||
| j = j + 1 | |||||
| fusion = fusion + 1 | |||||
| op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP) | |||||
| op.add_prim_attr('fusion', fusion) | |||||
| op_list = op_list + (op,) | |||||
| return op_list | |||||
| @reduce_opt.register("Function", "Number", "Function", "Tensor") | |||||
| def _tensors_allreduce_mean(mul, degree, allreduce, parameters): | |||||
| """ | |||||
| Apply allreduce on parameters. | |||||
| def _init_optimizer_allreduce(group): | |||||
| global _all_reduce_A | |||||
| _all_reduce_A = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP) | |||||
| _all_reduce_A.add_prim_attr('fusion', group) | |||||
| Args: | |||||
| mul(Primitive): The mul operator for parameters. | |||||
| degree (int): The mean coefficient. | |||||
| allreduce (Primitive): The communication operator for parameters. | |||||
| parameters (Tensor): The parameters before operation. | |||||
| @reduce_opt.register("Function", "Number", "Tensor") | |||||
| def _tensors_allreduce_mean(mul, degree, grad): | |||||
| degree = F.scalar_cast(degree, F.dtype(grad)) | |||||
| grad = _all_reduce_A(grad) | |||||
| Returns: | |||||
| Tensor, the parameters after operation. | |||||
| """ | |||||
| degree = F.scalar_cast(degree, F.dtype(parameters)) | |||||
| parameters = allreduce(parameters) | |||||
| cast_op = P.Cast() | cast_op = P.Cast() | ||||
| return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad))) | |||||
| @reduce_opt.register("Bool", "Tensor") | |||||
| def _tensors_allreduce(allreduce_filter, grad): | |||||
| if allreduce_filter: | |||||
| return _all_reduce_A(grad) | |||||
| return grad | |||||
| return mul(parameters, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(parameters))) | |||||
| _get_datatype = C.MultitypeFuncGraph("_get_datatype") | _get_datatype = C.MultitypeFuncGraph("_get_datatype") | ||||
| @_get_datatype.register("Tensor") | @_get_datatype.register("Tensor") | ||||
| def _tensors_get_datatype(grad): | |||||
| def _tensors_get_datatype(parameters): | |||||
| """ | """ | ||||
| Acquire gradient datatype. | |||||
| Acquire parameters datatype. | |||||
| Args: | Args: | ||||
| grad (Tensor): The gradient tensor before operation. | |||||
| parameters (Tensor): The parameters before operation. | |||||
| Returns: | Returns: | ||||
| mstype, the datatype of gradient. | |||||
| mstype, the datatype of parameters. | |||||
| """ | """ | ||||
| return F.dtype(grad) | |||||
| return F.dtype(parameters) | |||||
| _cast_datatype = C.MultitypeFuncGraph("_cast_datatype") | _cast_datatype = C.MultitypeFuncGraph("_cast_datatype") | ||||
| @_cast_datatype.register("TypeType", "Tensor") | @_cast_datatype.register("TypeType", "Tensor") | ||||
| def _tensors_cast_datatype(datatype, grad): | |||||
| def _tensors_cast_datatype(datatype, parameters): | |||||
| """ | """ | ||||
| Cast gradient to datatype. | |||||
| Cast parameters to datatype. | |||||
| Args: | Args: | ||||
| datatype (mstype): the destination datatype of gradient. | |||||
| grad (Tensor): The gradient tensor before operation. | |||||
| datatype (mstype): the destination datatype of parameters. | |||||
| parameters (Tensor): The parameters before operation. | |||||
| Returns: | Returns: | ||||
| Tensor, the gradient tensor after operation. | |||||
| Tensor, the parameters after operation. | |||||
| """ | """ | ||||
| return F.cast(grad, datatype) | |||||
| return F.cast(parameters, datatype) | |||||
| class DistributedGradReducerThor(Cell): | class DistributedGradReducerThor(Cell): | ||||
| """ | """ | ||||
| A distributed optimizer. | A distributed optimizer. | ||||
| Constructs a gradient reducer Cell, which applies communication and average operations on | |||||
| single-process gradient values. | |||||
| Constructs a parameters reducer Cell, which applies communication and average operations on | |||||
| single-process parameters values. | |||||
| Args: | Args: | ||||
| parameters (list): the parameters to be updated. | |||||
| mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: False. | |||||
| parameter_length (int): length of the parameters to be updated. | |||||
| split_indices(tuple): parameter split indices. | |||||
| mean (bool): When mean is true, the mean coefficient (degree) would apply on parameters. Default: False. | |||||
| degree (int): The mean coefficient. Usually it equals to device number. Default: None. | degree (int): The mean coefficient. Usually it equals to device number. Default: None. | ||||
| Raises: | Raises: | ||||
| ValueError: If degree is not a int or less than 0. | ValueError: If degree is not a int or less than 0. | ||||
| Examples: | |||||
| >>> from mindspore.communication import init, get_group_size | |||||
| >>> from mindspore.ops import composite as C | |||||
| >>> from mindspore.ops import operations as P | |||||
| >>> from mindspore.ops import functional as F | |||||
| >>> from mindspore import context | |||||
| >>> from mindspore import nn | |||||
| >>> from mindspore import ParameterTuple | |||||
| >>> from mindspore.context import ParallelMode | |||||
| >>> | |||||
| >>> device_id = int(os.environ["DEVICE_ID"]) | |||||
| >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, | |||||
| >>> device_id=int(device_id), enable_hccl=True) | |||||
| >>> init() | |||||
| >>> context.reset_auto_parallel_context() | |||||
| >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL) | |||||
| >>> | |||||
| >>> | |||||
| >>> class TrainingWrapper(nn.Cell): | |||||
| >>> def __init__(self, network, optimizer, sens=1.0): | |||||
| >>> super(TrainingWrapper, self).__init__(auto_prefix=False) | |||||
| >>> self.network = network | |||||
| >>> self.network.add_flags(defer_inline=True) | |||||
| >>> self.weights = ParameterTuple(network.trainable_params()) | |||||
| >>> self.optimizer = optimizer | |||||
| >>> self.grad = C.GradOperation(get_by_list=True, sens_param=True) | |||||
| >>> self.sens = sens | |||||
| >>> self.reducer_flag = False | |||||
| >>> self.grad_reducer = None | |||||
| >>> self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||||
| >>> if self.parallel_mode in [ParallelMode.DATA_PARALLEL, | |||||
| >>> ParallelMode.HYBRID_PARALLEL]: | |||||
| >>> self.reducer_flag = True | |||||
| >>> if self.reducer_flag: | |||||
| >>> mean = context.get_auto_parallel_context("gradients_mean") | |||||
| >>> if mean.get_device_num_is_set(): | |||||
| >>> degree = context.get_auto_parallel_context("device_num") | |||||
| >>> else: | |||||
| >>> degree = get_group_size() | |||||
| >>> self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) | |||||
| >>> | |||||
| >>> def construct(self, *args): | |||||
| >>> weights = self.weights | |||||
| >>> loss = self.network(*args) | |||||
| >>> sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||||
| >>> grads = self.grad(self.network, weights)(*args, sens) | |||||
| >>> if self.reducer_flag: | |||||
| >>> # apply grad reducer on grads | |||||
| >>> grads = self.grad_reducer(grads) | |||||
| >>> return F.depend(loss, self.optimizer(grads)) | |||||
| >>> | |||||
| >>> network = Net() | |||||
| >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
| >>> train_cell = TrainingWrapper(network, optimizer) | |||||
| >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32)) | |||||
| >>> label = Tensor(np.zeros([16, 16]).astype(np.float32)) | |||||
| >>> grads = train_cell(inputs, label) | |||||
| """ | """ | ||||
| def __init__(self, parameters, group, mean=True, degree=None): | |||||
| def __init__(self, parameter_length, split_indices, mean=True, degree=None): | |||||
| super(DistributedGradReducerThor, self).__init__(auto_prefix=False) | super(DistributedGradReducerThor, self).__init__(auto_prefix=False) | ||||
| self.hyper_map = C.HyperMap() | self.hyper_map = C.HyperMap() | ||||
| self.mul = P.Mul() | self.mul = P.Mul() | ||||
| @@ -165,16 +125,11 @@ class DistributedGradReducerThor(Cell): | |||||
| raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int") | raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int") | ||||
| self.degree = degree | self.degree = degree | ||||
| self.mean = mean | self.mean = mean | ||||
| self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters) | |||||
| _init_optimizer_allreduce(group) | |||||
| def construct(self, grads): | |||||
| # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the | |||||
| # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce, | |||||
| # and cast back after the operation. | |||||
| datatypes = self.hyper_map(F.partial(_get_datatype), grads) | |||||
| grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads) | |||||
| new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads) | |||||
| new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad) | |||||
| return new_grad | |||||
| self.op_list = _init_allreduce_operators(parameter_length, split_indices) | |||||
| def construct(self, parameters): | |||||
| datatypes = self.hyper_map(F.partial(_get_datatype), parameters) | |||||
| parameters = self.hyper_map(F.partial(_cast_datatype, mstype.float32), parameters) | |||||
| new_parameters = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.op_list, parameters) | |||||
| new_parameters = self.hyper_map(F.partial(_cast_datatype), datatypes, new_parameters) | |||||
| return new_parameters | |||||
| @@ -22,7 +22,7 @@ import mindspore.common.dtype as mstype | |||||
| from mindspore._checkparam import check_bool | from mindspore._checkparam import check_bool | ||||
| from mindspore._checkparam import Validator as validator | from mindspore._checkparam import Validator as validator | ||||
| from mindspore.nn.optim.optimizer import Optimizer | from mindspore.nn.optim.optimizer import Optimizer | ||||
| from mindspore.parallel._utils import _get_device_num, _get_gradients_mean | |||||
| from mindspore.parallel._utils import _get_device_num, _get_mirror_mean | |||||
| from src.grad_reducer_thor import DistributedGradReducerThor | from src.grad_reducer_thor import DistributedGradReducerThor | ||||
| _momentum_opt = C.MultitypeFuncGraph("momentum_opt") | _momentum_opt = C.MultitypeFuncGraph("momentum_opt") | ||||
| @@ -85,10 +85,12 @@ class THOR_GPU(Optimizer): | |||||
| self.assign = P.Assign() | self.assign = P.Assign() | ||||
| self.mul = P.Mul() | self.mul = P.Mul() | ||||
| mean = _get_gradients_mean() | |||||
| mean = _get_mirror_mean() | |||||
| degree = _get_device_num() | degree = _get_device_num() | ||||
| self.grad_reducer_thorA = DistributedGradReducerThor(self.parameters, 0, mean, degree) | |||||
| self.grad_reducer_thorG = DistributedGradReducerThor(self.parameters, 0, mean, degree) | |||||
| parameter_length = len(self.feature_map) | |||||
| self.grad_reducer_thorA = DistributedGradReducerThor(parameter_length, ((parameter_length,), 0), mean, degree) | |||||
| self.grad_reducer_thorG = DistributedGradReducerThor(parameter_length, ((parameter_length,), 0), mean, degree) | |||||
| self.weight_decay = weight_decay | self.weight_decay = weight_decay | ||||
| self.decay_flags = tuple(decay_filter(x) for x in self.parameters) | self.decay_flags = tuple(decay_filter(x) for x in self.parameters) | ||||
| self.update_gradient = P.UpdateThorGradient(split_dim=128) | self.update_gradient = P.UpdateThorGradient(split_dim=128) | ||||
| @@ -191,12 +193,13 @@ class THOR(Optimizer): | |||||
| 1.0 / 196, 1.0 / 196, 1.0 / 196, | 1.0 / 196, 1.0 / 196, 1.0 / 196, | ||||
| 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, | 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, | ||||
| 1.0] | 1.0] | ||||
| mean = _get_gradients_mean() | |||||
| mean = _get_mirror_mean() | |||||
| degree = _get_device_num() | degree = _get_device_num() | ||||
| self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree) | |||||
| self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree) | |||||
| self.grad_reducer_A = DistributedGradReducerThor(self.parameters, 3, mean, degree) | |||||
| self.grad_reducer_G = DistributedGradReducerThor(self.parameters, 4, mean, degree) | |||||
| parameter_length = len(self.feature_map) | |||||
| self.grad_reducer_Amax = DistributedGradReducerThor(parameter_length, ((27,), 2), mean, degree) | |||||
| self.grad_reducer_Gmax = DistributedGradReducerThor(parameter_length, ((27,), 4), mean, degree) | |||||
| self.grad_reducer_A = DistributedGradReducerThor(parameter_length, ((27,), 6), mean, degree) | |||||
| self.grad_reducer_G = DistributedGradReducerThor(parameter_length, ((27,), 8), mean, degree) | |||||
| self.matrix_A_inv = () | self.matrix_A_inv = () | ||||
| self.matrix_G_inv = () | self.matrix_G_inv = () | ||||
| self.matrix_max_inv = () | self.matrix_max_inv = () | ||||
| @@ -95,11 +95,7 @@ if __name__ == '__main__': | |||||
| context.set_context(device_id=device_id, enable_auto_mixed_precision=True) | context.set_context(device_id=device_id, enable_auto_mixed_precision=True) | ||||
| context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | ||||
| gradients_mean=True) | gradients_mean=True) | ||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1") | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2") | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3") | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4") | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5") | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([107]) | |||||
| init() | init() | ||||
| # GPU target | # GPU target | ||||
| else: | else: | ||||
| @@ -12,150 +12,109 @@ | |||||
| # See the License for the specific language governing permissions and | # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| """grad_reducer_thor""" | |||||
| import mindspore.common.dtype as mstype | |||||
| from mindspore.communication.management import GlobalComm, get_group_size | |||||
| """grad reducer cell for distributed training""" | |||||
| from mindspore.nn.cell import Cell | from mindspore.nn.cell import Cell | ||||
| from mindspore.communication.management import GlobalComm, get_group_size | |||||
| from mindspore.ops import functional as F, composite as C, operations as P | from mindspore.ops import functional as F, composite as C, operations as P | ||||
| from mindspore.ops.operations.comm_ops import AllReduce, ReduceOp | |||||
| from mindspore.ops.operations.comm_ops import AllReduce | |||||
| import mindspore.common.dtype as mstype | |||||
| reduce_opt = C.MultitypeFuncGraph("reduce_opt") | reduce_opt = C.MultitypeFuncGraph("reduce_opt") | ||||
| _all_reduce_A = AllReduce() | |||||
| def _init_allreduce_operators(length, split_indices): | |||||
| """ initialize allreduce communication operators""" | |||||
| indices = split_indices[0] | |||||
| fusion = split_indices[1] | |||||
| op_list = () | |||||
| j = 0 | |||||
| for i in range(length): | |||||
| if j <= len(indices)-1: | |||||
| temp = indices[j] | |||||
| else: | |||||
| temp = length | |||||
| if i >= temp: | |||||
| j = j + 1 | |||||
| fusion = fusion + 1 | |||||
| op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP) | |||||
| op.add_prim_attr('fusion', fusion) | |||||
| op_list = op_list + (op,) | |||||
| return op_list | |||||
| @reduce_opt.register("Function", "Number", "Function", "Tensor") | |||||
| def _tensors_allreduce_mean(mul, degree, allreduce, parameters): | |||||
| """ | |||||
| Apply allreduce on parameters. | |||||
| def _init_optimizer_allreduce(group): | |||||
| global _all_reduce_A | |||||
| _all_reduce_A = AllReduce(ReduceOp.SUM, GlobalComm.WORLD_COMM_GROUP) | |||||
| _all_reduce_A.add_prim_attr('fusion', group) | |||||
| Args: | |||||
| mul(Primitive): The mul operator for parameters. | |||||
| degree (int): The mean coefficient. | |||||
| allreduce (Primitive): The communication operator for parameters. | |||||
| parameters (Tensor): The parameters before operation. | |||||
| @reduce_opt.register("Function", "Number", "Tensor") | |||||
| def _tensors_allreduce_mean(mul, degree, grad): | |||||
| degree = F.scalar_cast(degree, F.dtype(grad)) | |||||
| grad = _all_reduce_A(grad) | |||||
| Returns: | |||||
| Tensor, the parameters after operation. | |||||
| """ | |||||
| degree = F.scalar_cast(degree, F.dtype(parameters)) | |||||
| parameters = allreduce(parameters) | |||||
| cast_op = P.Cast() | cast_op = P.Cast() | ||||
| return mul(grad, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(grad))) | |||||
| @reduce_opt.register("Bool", "Tensor") | |||||
| def _tensors_allreduce(allreduce_filter, grad): | |||||
| if allreduce_filter: | |||||
| return _all_reduce_A(grad) | |||||
| return grad | |||||
| return mul(parameters, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(parameters))) | |||||
| _get_datatype = C.MultitypeFuncGraph("_get_datatype") | _get_datatype = C.MultitypeFuncGraph("_get_datatype") | ||||
| @_get_datatype.register("Tensor") | @_get_datatype.register("Tensor") | ||||
| def _tensors_get_datatype(grad): | |||||
| def _tensors_get_datatype(parameters): | |||||
| """ | """ | ||||
| Acquire gradient datatype. | |||||
| Acquire parameters datatype. | |||||
| Args: | Args: | ||||
| grad (Tensor): The gradient tensor before operation. | |||||
| parameters (Tensor): The parameters before operation. | |||||
| Returns: | Returns: | ||||
| mstype, the datatype of gradient. | |||||
| mstype, the datatype of parameters. | |||||
| """ | """ | ||||
| return F.dtype(grad) | |||||
| return F.dtype(parameters) | |||||
| _cast_datatype = C.MultitypeFuncGraph("_cast_datatype") | _cast_datatype = C.MultitypeFuncGraph("_cast_datatype") | ||||
| @_cast_datatype.register("TypeType", "Tensor") | @_cast_datatype.register("TypeType", "Tensor") | ||||
| def _tensors_cast_datatype(datatype, grad): | |||||
| def _tensors_cast_datatype(datatype, parameters): | |||||
| """ | """ | ||||
| Cast gradient to datatype. | |||||
| Cast parameters to datatype. | |||||
| Args: | Args: | ||||
| datatype (mstype): the destination datatype of gradient. | |||||
| grad (Tensor): The gradient tensor before operation. | |||||
| datatype (mstype): the destination datatype of parameters. | |||||
| parameters (Tensor): The parameters before operation. | |||||
| Returns: | Returns: | ||||
| Tensor, the gradient tensor after operation. | |||||
| Tensor, the parameters after operation. | |||||
| """ | """ | ||||
| return F.cast(grad, datatype) | |||||
| return F.cast(parameters, datatype) | |||||
| class DistributedGradReducerThor(Cell): | class DistributedGradReducerThor(Cell): | ||||
| """ | """ | ||||
| A distributed optimizer. | A distributed optimizer. | ||||
| Constructs a gradient reducer Cell, which applies communication and average operations on | |||||
| single-process gradient values. | |||||
| Constructs a parameters reducer Cell, which applies communication and average operations on | |||||
| single-process parameters values. | |||||
| Args: | Args: | ||||
| parameters (list): the parameters to be updated. | |||||
| group (int): the different group to allreduce. | |||||
| mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: False. | |||||
| parameter_length (int): length of the parameters to be updated. | |||||
| split_indices(tuple): parameter split indices. | |||||
| mean (bool): When mean is true, the mean coefficient (degree) would apply on parameters. Default: False. | |||||
| degree (int): The mean coefficient. Usually it equals to device number. Default: None. | degree (int): The mean coefficient. Usually it equals to device number. Default: None. | ||||
| Raises: | Raises: | ||||
| ValueError: If degree is not a int or less than 0. | ValueError: If degree is not a int or less than 0. | ||||
| Examples: | |||||
| >>> from mindspore.communication import init, get_group_size | |||||
| >>> from mindspore.ops import composite as C | |||||
| >>> from mindspore.ops import operations as P | |||||
| >>> from mindspore.ops import functional as F | |||||
| >>> from mindspore import context | |||||
| >>> from mindspore import nn | |||||
| >>> from mindspore import ParameterTuple | |||||
| >>> from mindspore.context import ParallelMode | |||||
| >>> | |||||
| >>> device_id = int(os.environ["DEVICE_ID"]) | |||||
| >>> context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, | |||||
| >>> device_id=int(device_id), enable_hccl=True) | |||||
| >>> init() | |||||
| >>> context.reset_auto_parallel_context() | |||||
| >>> context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL) | |||||
| >>> | |||||
| >>> | |||||
| >>> class TrainingWrapper(nn.Cell): | |||||
| >>> def __init__(self, network, optimizer, sens=1.0): | |||||
| >>> super(TrainingWrapper, self).__init__(auto_prefix=False) | |||||
| >>> self.network = network | |||||
| >>> self.network.add_flags(defer_inline=True) | |||||
| >>> self.weights = ParameterTuple(network.trainable_params()) | |||||
| >>> self.optimizer = optimizer | |||||
| >>> self.grad = C.GradOperation(get_by_list=True, sens_param=True) | |||||
| >>> self.sens = sens | |||||
| >>> self.reducer_flag = False | |||||
| >>> self.grad_reducer = None | |||||
| >>> self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||||
| >>> if self.parallel_mode in [ParallelMode.DATA_PARALLEL, | |||||
| >>> ParallelMode.HYBRID_PARALLEL]: | |||||
| >>> self.reducer_flag = True | |||||
| >>> if self.reducer_flag: | |||||
| >>> mean = context.get_auto_parallel_context("gradients_mean") | |||||
| >>> if mean.get_device_num_is_set(): | |||||
| >>> degree = context.get_auto_parallel_context("device_num") | |||||
| >>> else: | |||||
| >>> degree = get_group_size() | |||||
| >>> self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) | |||||
| >>> | |||||
| >>> def construct(self, *args): | |||||
| >>> weights = self.weights | |||||
| >>> loss = self.network(*args) | |||||
| >>> sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||||
| >>> grads = self.grad(self.network, weights)(*args, sens) | |||||
| >>> if self.reducer_flag: | |||||
| >>> # apply grad reducer on grads | |||||
| >>> grads = self.grad_reducer(grads) | |||||
| >>> return F.depend(loss, self.optimizer(grads)) | |||||
| >>> | |||||
| >>> network = Net() | |||||
| >>> optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
| >>> train_cell = TrainingWrapper(network, optimizer) | |||||
| >>> inputs = Tensor(np.ones([16, 16]).astype(np.float32)) | |||||
| >>> label = Tensor(np.zeros([16, 16]).astype(np.float32)) | |||||
| >>> grads = train_cell(inputs, label) | |||||
| """ | """ | ||||
| def __init__(self, parameters, group, mean=True, degree=None): | |||||
| def __init__(self, parameter_length, split_indices, mean=True, degree=None): | |||||
| super(DistributedGradReducerThor, self).__init__(auto_prefix=False) | super(DistributedGradReducerThor, self).__init__(auto_prefix=False) | ||||
| self.hyper_map = C.HyperMap() | self.hyper_map = C.HyperMap() | ||||
| self.mul = P.Mul() | self.mul = P.Mul() | ||||
| @@ -166,20 +125,11 @@ class DistributedGradReducerThor(Cell): | |||||
| raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int") | raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int") | ||||
| self.degree = degree | self.degree = degree | ||||
| self.mean = mean | self.mean = mean | ||||
| self.allreduce_filter = tuple(x.layerwise_parallel is False for x in parameters) | |||||
| _init_optimizer_allreduce(group) | |||||
| def construct(self, grads): | |||||
| # In some circumstances, the data precision of grads could be mixed with float16 and float32. Thus, the | |||||
| # result of AllReduce is unreliable. To solve the problem, grads should be cast to float32 before AllReduce, | |||||
| # and cast back after the operation. | |||||
| datatypes = self.hyper_map(F.partial(_get_datatype), grads) | |||||
| grads = self.hyper_map(F.partial(_cast_datatype, mstype.float32), grads) | |||||
| if self.mean: | |||||
| new_grad = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), grads) | |||||
| else: | |||||
| new_grad = self.hyper_map(F.partial(reduce_opt), self.allreduce_filter, grads) | |||||
| new_grad = self.hyper_map(F.partial(_cast_datatype), datatypes, new_grad) | |||||
| return new_grad | |||||
| self.op_list = _init_allreduce_operators(parameter_length, split_indices) | |||||
| def construct(self, parameters): | |||||
| datatypes = self.hyper_map(F.partial(_get_datatype), parameters) | |||||
| parameters = self.hyper_map(F.partial(_cast_datatype, mstype.float32), parameters) | |||||
| new_parameters = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.op_list, parameters) | |||||
| new_parameters = self.hyper_map(F.partial(_cast_datatype), datatypes, new_parameters) | |||||
| return new_parameters | |||||
| @@ -89,10 +89,11 @@ class THOR(Optimizer): | |||||
| 1.0] | 1.0] | ||||
| mean = _get_gradients_mean() | mean = _get_gradients_mean() | ||||
| degree = _get_device_num() | degree = _get_device_num() | ||||
| self.grad_reducer_Amax = DistributedGradReducerThor(self.parameters, 2, mean, degree) | |||||
| self.grad_reducer_Gmax = DistributedGradReducerThor(self.parameters, 5, mean, degree) | |||||
| self.grad_reducer_A = DistributedGradReducerThor(self.parameters, 3, mean, degree) | |||||
| self.grad_reducer_G = DistributedGradReducerThor(self.parameters, 4, mean, degree) | |||||
| parameter_length = len(self.feature_map) | |||||
| self.grad_reducer_Amax = DistributedGradReducerThor(parameter_length, ((27,), 2), mean, degree) | |||||
| self.grad_reducer_Gmax = DistributedGradReducerThor(parameter_length, ((27,), 4), mean, degree) | |||||
| self.grad_reducer_A = DistributedGradReducerThor(parameter_length, ((27,), 6), mean, degree) | |||||
| self.grad_reducer_G = DistributedGradReducerThor(parameter_length, ((27,), 8), mean, degree) | |||||
| self.matrix_A_inv = () | self.matrix_A_inv = () | ||||
| self.matrix_G_inv = () | self.matrix_G_inv = () | ||||
| self.matrix_max_inv = () | self.matrix_max_inv = () | ||||
| @@ -241,11 +241,7 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl): | |||||
| if enable_hccl: | if enable_hccl: | ||||
| context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, | ||||
| gradients_mean=True, parameter_broadcast=True) | gradients_mean=True, parameter_broadcast=True) | ||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1") | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2") | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3") | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4") | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5") | |||||
| auto_parallel_context().set_all_reduce_fusion_split_indices([107]) | |||||
| init() | init() | ||||
| # network | # network | ||||