@@ -20,8 +20,11 @@ from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 import mindspore.common.dtype as mstype
 import mindspore.context as context
-from mindspore._checkparam import check_int_positive, check_bool, check_typename
+from mindspore._checkparam import check_bool, check_typename
 from mindspore._extends import cell_attr_register
+from mindspore.communication.management import get_group_size, get_rank
+from mindspore.communication import management
+from mindspore._checkparam import check_int_positive
 from ..cell import Cell

@@ -30,6 +33,7 @@ class _BatchNorm(Cell):
     @cell_attr_register
     def __init__(self,
                  num_features,
+                 group=1,
                  eps=1e-5,
                  momentum=0.9,
                  affine=True,

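Since the new `group` argument defaults to 1, existing subclasses such as `BatchNorm1d` and `BatchNorm2d` are unaffected: no rank lookup and no communication group creation happens unless a caller opts in with `group > 1`. A minimal sketch of the unchanged single-device path (shapes are illustrative):

```python
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# Single-device usage is unchanged: group stays at its default of 1,
# so the constructor never touches the communication layer.
bn = nn.BatchNorm2d(num_features=3)
x = Tensor(np.ones([1, 3, 4, 4]), mindspore.float32)
y = bn(x)
```
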
@@ -56,6 +60,22 @@ class _BatchNorm(Cell):
             gamma_init, num_features), name="gamma", requires_grad=affine)
         self.beta = Parameter(initializer(
             beta_init, num_features), name="beta", requires_grad=affine)
+        self.group = check_int_positive(group)
+        self.is_global = False  # flipped below once this rank joins a sync group
+        if self.group != 1:
+            self.rank_id = get_rank()
+            self.rank_size = get_group_size()
+            self.device_list = [i for i in range(0, self.rank_size)]
+            self.rank_list = self.list_group(self.device_list, self.group)
+            self.rank_list_idx = len(self.rank_list)
+            for i in range(self.rank_list_idx):
+                if self.rank_id in self.rank_list[i] and self.group != 1:
+                    self.is_global = True
+                    management.create_group('group' + str(i), self.rank_list[i])
+                    self.all_reduce = P.AllReduce(P.ReduceOp.SUM, 'group' + str(i)).add_prim_attr('fusion', 1)
+        self.shape = P.Shape()
+        self.reduce_mean = P.ReduceMean()
+        self.square = P.Square()
         if context.get_context("enable_ge"):
             self.is_ge_backend = True

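The grouping above relies on `list_group` (added in the next hunk) to slice the flat rank list `[0, rank_size)` into consecutive groups of `group` devices; each rank then creates one communication group together with its peers. A pure-Python sketch of just that partitioning step, with no MindSpore dependency:

```python
# Standalone sketch of the rank partitioning used above; mirrors list_group
# from the next hunk but runs anywhere.
def list_group(world_rank, group_size):
    if len(world_rank) % group_size != 0:
        raise ValueError("group size must evenly divide the number of ranks")
    # (iter(seq),) * n is a tuple of n references to one shared iterator, so
    # zip() pulls n consecutive items per output tuple, i.e. fixed-size chunks.
    return [list(chunk) for chunk in zip(*(iter(world_rank),) * group_size)]

print(list_group(list(range(8)), 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(list_group(list(range(8)), 2))  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```
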
@@ -82,22 +102,53 @@ class _BatchNorm(Cell):
     def _check_data_dim(self, x):
         raise NotImplementedError

+    def list_group(self, world_rank, group_size):
+        if group_size > get_group_size():
+            raise ValueError("group size cannot be greater than the local rank size; group size is {}, "
+                             "local rank size is {}".format(group_size, get_group_size()))
+        if len(world_rank) % group_size != 0:
+            raise ValueError("the world rank size must be divisible by the group size.")
+        world_rank_list = zip(*(iter(world_rank),) * group_size)
+        group_list = [list(i) for i in world_rank_list]
+        return group_list
+
     def construct(self, x):
         if self.training and self.use_batch_statistics:
             if self.is_ge_backend:
-                y, batch_mean, batch_var, _, _ = \
-                    self.bn_train(x,
-                                  self.gamma,
-                                  self.beta,
-                                  None,
-                                  None)
-
-                mean_sub = self.sub_mean(self.moving_mean, batch_mean)
-                temp_mean = self.mul_mean(mean_sub, self.momentum)
-                mean_sub2 = self.sub_var(self.moving_variance, batch_var)
-                temp_variance = self.mul_var(mean_sub2, self.momentum)
-                y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
-                y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
+                if self.is_global:
+                    x_mean = self.reduce_mean(x)
+                    x_mean_square = self.reduce_mean(self.square(x))
+                    global_batch_mean = self.all_reduce(x_mean) / self.group
+                    global_batch_mean_square = self.all_reduce(x_mean_square) / self.group
+                    global_mean = global_batch_mean
+                    global_var = global_batch_mean_square - self.square(global_batch_mean)
+                    y, batch_mean, batch_var, _, _ = \
+                        self.bn_train(x,
+                                      self.gamma,
+                                      self.beta,
+                                      None,
+                                      None)
+
+                    mean_sub = self.sub_mean(self.moving_mean, global_mean)
+                    temp_mean = self.mul_mean(mean_sub, self.momentum)
+                    mean_sub2 = self.sub_var(self.moving_variance, global_var)
+                    temp_variance = self.mul_var(mean_sub2, self.momentum)
+                    y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
+                    y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
+                else:
+                    y, batch_mean, batch_var, _, _ = \
+                        self.bn_train(x,
+                                      self.gamma,
+                                      self.beta,
+                                      None,
+                                      None)
+
+                    mean_sub = self.sub_mean(self.moving_mean, batch_mean)
+                    temp_mean = self.mul_mean(mean_sub, self.momentum)
+                    mean_sub2 = self.sub_var(self.moving_variance, batch_var)
+                    temp_variance = self.mul_var(mean_sub2, self.momentum)
+                    y = F.depend(y, self.assign_sub_mean(self.moving_mean, temp_mean))
+                    y = F.depend(y, self.assign_sub_var(self.moving_variance, temp_variance))
             else:
                 y = self.bn_train(x,
                                   self.gamma,

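Two identities carry the global branch above. First, each device AllReduces only its local mean and local mean of squares; the group variance is then recovered as Var[x] = E[x^2] - (E[x])^2, which equals the variance of the pooled batch as long as every device holds the same number of samples. Second, subtracting `momentum * (moving - new)` from a moving statistic is an exponential moving average: `moving = (1 - momentum) * moving + momentum * new`. A quick numpy check of both, on made-up data rather than the MindSpore API:

```python
import numpy as np

rng = np.random.default_rng(0)
shards = [rng.normal(size=8) for _ in range(4)]  # 4 devices, 8 samples each

# What each device contributes before the AllReduce.
local_means = [s.mean() for s in shards]
local_mean_squares = [(s ** 2).mean() for s in shards]

# AllReduce(SUM) followed by division by the group size is an average.
global_mean = sum(local_means) / 4
global_mean_square = sum(local_mean_squares) / 4
global_var = global_mean_square - global_mean ** 2  # Var[x] = E[x^2] - E[x]^2

pooled = np.concatenate(shards)
assert np.allclose(global_mean, pooled.mean())
assert np.allclose(global_var, pooled.var())  # equal shard sizes make this exact

# moving -= momentum * (moving - new)  is the EMA
# moving = (1 - momentum) * moving + momentum * new.
momentum, moving, new = 0.9, 5.0, 1.0
assert np.isclose(moving - momentum * (moving - new),
                  (1 - momentum) * moving + momentum * new)
```
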
@@ -221,6 +272,55 @@ class BatchNorm2d(_BatchNorm):
         pass


+class GlobalBatchNorm(_BatchNorm):
+    r"""
+    Global normalization layer over an N-dimensional input.
+
+    Global normalization is cross-device synchronized batch normalization. A batch normalization implementation
+    only normalizes the data within each device, while global normalization normalizes the input within the group.
+    It has been described in the paper `Batch Normalization: Accelerating Deep Network Training by
+    Reducing Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`_. It rescales and recenters the
+    feature using a mini-batch of data and the learned parameters, as described in the following formula.
+
+    .. math::
+        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+    Args:
+        num_features (int): `C` from an expected input of size (N, C, H, W).
+        group (int): The number of devices in each group.
+        eps (float): A value added to the denominator for numerical stability. Default: 1e-5.
+        momentum (float): A floating-point hyperparameter for the momentum of the
+            running_mean and running_var computation. Default: 0.9.
+        gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'ones'.
+        beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta weight.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'zeros'.
+        moving_mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving mean.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'zeros'.
+        moving_var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the moving variance.
+            The values of str refer to the function `initializer` including 'zeros', 'ones', 'xavier_uniform',
+            'he_uniform', etc. Default: 'ones'.
+        use_batch_statistics (bool): If true, use the mean and variance of the current batch of data; otherwise,
+            use the specified moving mean and moving variance. Default: True.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
+
+    Outputs:
+        Tensor, the normalized, scaled and offset tensor, of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
+
+    Examples:
+        >>> global_bn_op = nn.GlobalBatchNorm(num_features=3, group=4)
+        >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
+        >>> global_bn_op(input)
+    """
+
+    def _check_data_dim(self, x):
+        if x.dim() == 0:
+            pass
+
+
 class LayerNorm(Cell):
     r"""
     Applies Layer Normalization over a mini-batch of inputs.

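To make the docstring formula concrete: per channel, the layer subtracts the group mean, divides by the square root of the group variance plus `eps`, and applies the learned scale and shift. A numpy-only sketch of the arithmetic for one channel across four simulated devices (names and shapes are illustrative, not the MindSpore API):

```python
import numpy as np

eps, gamma, beta = 1e-5, 1.0, 0.0

rng = np.random.default_rng(1)
# One channel's data on each of 4 devices: (N, H, W) per device.
device_batches = [rng.normal(loc=2.0, scale=3.0, size=(2, 4, 4)) for _ in range(4)]

# Group statistics, assembled the same way construct() does it.
mean = sum(b.mean() for b in device_batches) / 4
mean_sq = sum((b ** 2).mean() for b in device_batches) / 4
var = mean_sq - mean ** 2

# y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta, applied on each device.
normalized = [(b - mean) / np.sqrt(var + eps) * gamma + beta for b in device_batches]

pooled = np.concatenate([n.ravel() for n in normalized])
print(pooled.mean(), pooled.var())  # ~0.0 and ~1.0 across the whole group
```
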