Browse Source

!12004 thor generalization code submit

From: @sl_wang
Reviewed-by: @guoqi1024
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
1f3b059195
25 changed files with 2687 additions and 1140 deletions
  1. +3
    -1
      mindspore/nn/layer/__init__.py
  2. +672
    -0
      mindspore/nn/layer/thor_layer.py
  3. +2
    -1
      mindspore/nn/optim/__init__.py
  4. +897
    -0
      mindspore/nn/optim/thor.py
  5. +3
    -2
      mindspore/nn/wrap/grad_reducer.py
  6. +19
    -0
      mindspore/train/train_thor/__init__.py
  7. +157
    -0
      mindspore/train/train_thor/convert_utils.py
  8. +188
    -0
      mindspore/train/train_thor/dataset_helper.py
  9. +236
    -0
      mindspore/train/train_thor/model_thor.py
  10. +50
    -0
      model_zoo/official/cv/resnet/src/config.py
  11. +33
    -0
      model_zoo/official/cv/resnet/src/lr_generator.py
  12. +74
    -4
      model_zoo/official/cv/resnet/src/resnet.py
  13. +29
    -7
      model_zoo/official/cv/resnet/train.py
  14. +1
    -1
      model_zoo/official/nlp/bert/scripts/run_distributed_pretrain_ascend.sh
  15. +1
    -0
      model_zoo/official/nlp/bert/scripts/run_standalone_pretrain_ascend.sh
  16. +6
    -0
      model_zoo/official/nlp/bert/src/config.py
  17. +61
    -3
      model_zoo/official/nlp/bert/src/utils.py
  18. +45
    -38
      tests/st/networks/models/bert/bert_performance/test_bert_thor.py
  19. +1
    -1
      tests/st/networks/models/resnet50/src_thor/config.py
  20. +0
    -135
      tests/st/networks/models/resnet50/src_thor/grad_reducer_thor.py
  21. +0
    -88
      tests/st/networks/models/resnet50/src_thor/lr_generator.py
  22. +188
    -154
      tests/st/networks/models/resnet50/src_thor/resnet.py
  23. +0
    -202
      tests/st/networks/models/resnet50/src_thor/thor.py
  24. +0
    -479
      tests/st/networks/models/resnet50/src_thor/thor_layer.py
  25. +21
    -24
      tests/st/networks/models/resnet50/test_resnet50_imagenet.py

+ 3
- 1
mindspore/nn/layer/__init__.py View File

@@ -18,7 +18,7 @@ Layer.
The high-level components(Cells) used to construct the neural network.
"""
from . import activation, normalization, container, conv, lstm, basic, embedding, pooling, image, quant, math, \
combined, timedistributed
combined, timedistributed, thor_layer
from .activation import *
from .normalization import *
from .container import *
@@ -32,6 +32,7 @@ from .quant import *
from .math import *
from .combined import *
from .timedistributed import *
from .thor_layer import *

__all__ = []
__all__.extend(activation.__all__)
@@ -47,3 +48,4 @@ __all__.extend(quant.__all__)
__all__.extend(math.__all__)
__all__.extend(combined.__all__)
__all__.extend(timedistributed.__all__)
__all__.extend(thor_layer.__all__)

+ 672
- 0
mindspore/nn/layer/thor_layer.py View File

@@ -0,0 +1,672 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""layers for second order optimization"""
import numpy as np
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer, Initializer
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore._checkparam import Validator, Rel, twice
from mindspore import context
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation


__all__ = ['Dense_Thor', 'Conv2d_Thor', 'Embedding_Thor']

class Dense_Thor(Cell):
r"""
The dense connected layer.

Applies dense connected layer for the input. This layer implements the operation as:

.. math::
\text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),

where :math:`\text{activation}` is the activation function passed as the activation
argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
data type as the inputs created by the layer, and :math:`\text{bias}` is a bias vector
with the same data type as the inputs created by the layer (only if has_bias is True).

Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): activate function applied to the output of the fully connected layer, eg. 'ReLU'.
Default: None.

Raises:
ValueError: If weight_init or bias_init shape is incorrect.

Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.

Outputs:
Tensor of shape :math:`(N, out\_channels)`.

Examples:
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net = nn.Dense(3, 4)
>>> net(input)
[[ 2.5246444 2.2738023 0.5711005 -3.9399147 ]
[ 1.0739875 4.0155234 0.94188046 -5.459526 ]]
"""
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True,
activation=None):
super(Dense_Thor, self).__init__()
self.thor = True
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
self.has_bias = Validator.check_bool(has_bias)
if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("Weight init shape error.")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")

self.bias = None
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels:
raise ValueError("Bias init shape error.")
self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
self.bias_add = P.BiasAdd()

self.matmul = P.MatMul(transpose_b=True)
self.activation = get_activation(activation)
self.activation_flag = self.activation is not None

self.matrix_A = Parameter(Tensor(np.zeros([in_channels, in_channels]).astype(np.float32)),
name='matrix_A', requires_grad=False)
self.shape = P.Shape()
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.mul = P.Mul()
self.is_Ascend = True
if context.get_context("device_target") == "Ascend":
if out_channels == 1001:
self.matrix_G = Parameter(Tensor(np.zeros([1024, 1024]).astype(np.float32)),
name='matrix_G', requires_grad=False)
self.pad = P.Pad(((0, 23), (0, 23)))
self.pad1 = P.Pad(((0, 7), (0, 7)))
self.slice = P.Slice()
self.add = P.TensorAdd()
else:
self.matrix_G = Parameter(Tensor(np.eye(out_channels).astype(np.float32)),
name="matrix_G", requires_grad=False)
self.abs = P.Abs()
self.reduce_max = P.ReduceMax(keep_dims=False)
self.neg = P.Neg()
self.reduce_sum = P.ReduceSum()
self.matmul = P.MatMul(transpose_b=True)
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
self.cast = P.Cast()
self.is_nsp_layer = (out_channels == 2)
else:
self.is_Ascend = False
self.matrix_G = Parameter(Tensor(np.eye(out_channels).astype(np.float32)),
name="matrix_G", requires_grad=False)
self.cube_matmul = P.MatMul(transpose_a=True)
self.getG = P.InsertGradientOf(self.save_gradient)


def save_gradient(self, dout):
"""
this function only for thor optimizer
save_gradient
"""
out = dout
if self.is_Ascend:
if not self.is_nsp_layer:
shape = self.shape(dout)
normalizer = self.cast(shape[0], mstype.float32)
matrix_G = self.cube_matmul(dout, dout)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
if self.out_channels == 1001:
matrix_G = P.Pad(((0, 23), (0, 23)))(matrix_G)
self.matrix_G = matrix_G
else:
dout_shape = self.shape(dout)
normalizer = dout_shape[0]
matrix_G = self.cube_matmul(dout, dout)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
self.matrix_G = matrix_G
return out

def construct(self, x):
if self.thor:
if self.is_Ascend:
inputs = self.cube_matmul(x, x)
shape = self.shape(x)
normalizer = self.cast(shape[0], mstype.float32)
matrix_A = self.mul(inputs, 1.0 / normalizer)
self.matrix_A = matrix_A
else:
inputs = self.cube_matmul(x, x)
inputs_shape = self.shape(inputs)
normalizer = inputs_shape[0]
matrix_A = self.mul(inputs, 1.0 / normalizer)
self.matrix_A = matrix_A
x = self.matmul(x, self.weight)
x = self.getG(x)
else:
x = self.matmul(x, self.weight)
if self.has_bias:
x = self.bias_add(x, self.bias)
if self.activation_flag:
x = self.activation(x)
return x

def extend_repr(self):
s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
if self.has_bias:
s += ', has_bias={}'.format(self.has_bias)
# if self.activation_flag:
# s += ', activation={}'.format(self.activation)
return s

class _Conv(Cell):
"""
Applies a N-D convolution over an input signal composed of several input planes.
"""

def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
has_bias,
weight_init,
bias_init,
transposed=False):
super(_Conv, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
self.kernel_size = kernel_size
self.stride = stride
self.pad_mode = pad_mode
# self.weight_init = weight_init
self.bias_init = bias_init
if isinstance(padding, int):
Validator.check_non_negative_int(padding, 'padding', self.cls_name)
self.padding = padding
elif isinstance(padding, tuple):
for pad in padding:
Validator.check_non_negative_int(pad, 'padding item', self.cls_name)
self.padding = padding
else:
raise TypeError("padding type must be int/tuple(int) cannot be {}!".format(type(padding)))

self.dilation = dilation
self.group = Validator.check_positive_int(group)
self.has_bias = has_bias
if (not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
isinstance(kernel_size[0], bool) or isinstance(kernel_size[1], bool) or \
kernel_size[0] < 1 or kernel_size[1] < 1:
raise ValueError("Attr 'kernel_size' of 'Conv2D' Op passed "
+ str(self.kernel_size) + ", should be a int or tuple and equal to or greater than 1.")
if (not isinstance(stride[0], int)) or (not isinstance(stride[1], int)) or \
isinstance(stride[0], bool) or isinstance(stride[1], bool) or stride[0] < 1 or stride[1] < 1:
raise ValueError("Attr 'stride' of 'Conv2D' Op passed "
+ str(self.stride) + ", should be a int or tuple and equal to or greater than 1.")
if (not isinstance(dilation[0], int)) or (not isinstance(dilation[1], int)) or \
isinstance(dilation[0], bool) or isinstance(dilation[1], bool) or dilation[0] < 1 or dilation[1] < 1:
raise ValueError("Attr 'dilation' of 'Conv2D' Op passed "
+ str(self.dilation) + ", should be a int or tuple and equal to or greater than 1.")
if in_channels % group != 0:
raise ValueError("Attr 'in_channels' of 'Conv2D' Op must be divisible by "
"attr 'group' of 'Conv2D' Op.")
if out_channels % group != 0:
raise ValueError("Attr 'out_channels' of 'Conv2D' Op must be divisible by "
"attr 'group' of 'Conv2D' Op.")
if transposed:
shape = [in_channels, out_channels // group, *kernel_size]
else:
shape = [out_channels, in_channels // group, *kernel_size]
self.weight = Parameter(initializer(weight_init, shape), name='weight')

if Validator.check_bool(has_bias):
self.bias = Parameter(initializer(self.bias_init, [out_channels]), name='bias')
else:
if self.bias_init != 'zeros':
logger.warning("Value of 'has_bias' is False, value of 'bias_init' will be ignored.")
self.bias = None

def construct(self, *inputs):
"""Must be overridden by all subclasses."""
raise NotImplementedError


class Conv2d_Thor(_Conv):
r"""
2D convolution layer.

Applies a 2D convolution over an input tensor which is typically of shape :math:`(N, C_{in}, H_{in}, W_{in})`,
where :math:`N` is batch size, :math:`C_{in}` is channel number, and :math:`H_{in}, W_{in})` are height and width.
For each batch of shape :math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as:

.. math::

out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,

where :math:`ccor` is the cross-correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to the :math:`i`-th channel of the :math:`j`-th
filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
of kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and
:math:`\text{ks_w}` are the height and width of the convolution kernel. The full kernel has shape
:math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group number
to split the input in the channel dimension.

If the 'pad_mode' is set to be "valid", the output height and width will be
:math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
(\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
:math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
(\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.

The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition
<http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.

Args:
in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the height
and width of the 2D convolution window. Single int means the value is for both the height and the width of
the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
width of the kernel.
stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
the height and width of movement are both strides, or a tuple of two int numbers that
represent height and width of movement respectively. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are
"same", "valid", "pad". Default: "same".

- same: Adopts the way of completion. The height and width of the output will be the same as
the input. The total number of padding will be calculated in horizontal and vertical
directions and evenly distributed to top and bottom, left and right if possible. Otherwise, the
last extra padding will be done from the bottom and the right side. If this mode is set, `padding`
must be 0.

- valid: Adopts the way of discarding. The possible largest height and width of output will be returned
without padding. Extra pixels will be discarded. If this mode is set, `padding`
must be 0.

- pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
Tensor borders. `padding` must be greater than or equal to 0.

padding (Union[int, tuple[int]]): Implicit paddings on both sides of the input. If `padding` is one integer,
the paddings of top, bottom, left and right are the same, equal to padding. If `padding` is a tuple
with four integers, the paddings of top, bottom, left and right will be equal to padding[0],
padding[1], padding[2], and padding[3] accordingly. Default: 0.
dilation (Union[int, tuple[int]]): The data type is int or a tuple of 2 integers. Specifies the dilation rate
to use for dilated convolution. If set to be :math:`k > 1`, there will
be :math:`k - 1` pixels skipped for each sampling location. Its value must
be greater or equal to 1 and bounded by the height and width of the
input. Default: 1.
group (int): Splits filter into groups, `in_ channels` and `out_channels` must be
divisible by the number of groups. If the group is equal to `in_channels` and `out_channels`,
this 2D convolution layer also can be called 2D depthwise convolution layer. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
It can be a Tensor, a string, an Initializer or a number. When a string is specified,
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
Initializer for more details. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
Initializer and string are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'.

Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.

Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

Examples:
>>> net = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> net(input).shape
(1, 240, 1024, 640)
"""

def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros'):
kernel_size = twice(kernel_size)
stride = twice(stride)
self._dilation = dilation
dilation = twice(dilation)
super(Conv2d_Thor, self).__init__(
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
has_bias,
weight_init,
bias_init)
self.conv2d = P.Conv2D(out_channel=self.out_channels,
kernel_size=self.kernel_size,
mode=1,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation,
group=self.group)
self._init_depthwise_conv2d(weight_init)
self.bias_add = P.BiasAdd()

self.thor = True
self.hw = kernel_size[0] * kernel_size[1]
self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
self.matrix_G_dim = self.out_channels
self.shape = P.Shape()
self.reshape = P.Reshape()
self.mul = P.Mul()
self.cast = P.Cast()
self.A_normalizer = Parameter(initializer(0, [1], mstype.float32), name="A_normalizer", requires_grad=False)
self.G_normalizer = Parameter(initializer(0, [1], mstype.float32), name="G_normalizer", requires_grad=False)
self.is_Ascend = True
if context.get_context("device_target") == "Ascend":
ksizes = (1, kernel_size[0], kernel_size[1], 1)
strides = (1, stride[0], stride[1], 1)
self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides)
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
self.transpose02314 = P.CusTranspose02314()
dampingA_dim = self.matrix_A_dim
self.diag_block_dim = 128
if (self.matrix_A_dim % self.diag_block_dim) != 0 and self.matrix_A_dim > self.diag_block_dim:
dampingA_dim = (self.matrix_A_dim // self.diag_block_dim + 1) * self.diag_block_dim
dampingG_dim = self.matrix_G_dim
if (self.matrix_G_dim % self.diag_block_dim) != 0 and self.matrix_G_dim > self.diag_block_dim:
dampingG_dim = (self.matrix_G_dim // self.diag_block_dim + 1) * self.diag_block_dim
self.matrix_A_cov = Parameter(Tensor(np.zeros([dampingA_dim, dampingA_dim]).astype(np.float32)),
name='matrix_A', requires_grad=False)
self.matrix_G_cov = Parameter(Tensor(np.zeros([dampingG_dim, dampingG_dim]).astype(np.float32)),
name='matrix_G', requires_grad=False)

self.channels_slice_flag = False
self.C0 = 16
if self.in_channels % self.C0 != 0:
self.channels_slice_flag = True
self.padA_flag = False
if (self.matrix_A_dim // self.diag_block_dim) * self.diag_block_dim != self.matrix_A_dim \
and self.matrix_A_dim > self.diag_block_dim:
self.padA_flag = True
pad_dim = self.diag_block_dim - self.matrix_A_dim % self.diag_block_dim
self.padA = P.Pad(((0, pad_dim), (0, pad_dim)))
self.slice = P.Slice()
else:
self.is_Ascend = False
self.img2col = P.Im2Col(kernel_size=kernel_size, stride=stride, pad_mode="same")
self.matmul = P.MatMul(transpose_b=True)
self.reduce_mean = P.ReduceMean(keep_dims=False)
self.matrix_A_cov = Parameter(Tensor(np.zeros([self.matrix_A_dim, self.matrix_A_dim]).astype(np.float32)),
name='matrix_A', requires_grad=False)
self.matrix_G_cov = Parameter(Tensor(np.zeros([self.matrix_G_dim, self.matrix_G_dim]).astype(np.float32)),
name='matrix_G', requires_grad=False)
self.getG = P.InsertGradientOf(self.save_gradient)


def _init_depthwise_conv2d(self, weight_init):
"""Initialize depthwise conv2d op"""
if context.get_context("device_target") == "Ascend" and self.group > 1:
self.dilation = self._dilation
Validator.check_integer('group', self.group, self.in_channels, Rel.EQ)
Validator.check_integer('group', self.group, self.out_channels, Rel.EQ)
self.conv2d = P.DepthwiseConv2dNative(channel_multiplier=1,
kernel_size=self.kernel_size,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation)
weight_shape = [1, self.in_channels, *self.kernel_size]
self.weight_init = weight_init
if isinstance(weight_init, Tensor):
self.weight_init = Tensor(weight_init.asnumpy().swapaxes(0, 1), weight_init.dtype)
if isinstance(weight_init, Initializer):
self.weight_init.shape = weight_shape
self.weight = Parameter(initializer(self.weight_init, weight_shape), name='weight')


def save_gradient(self, dout):
"""save_gradient"""
out = dout
if self.is_Ascend:
dout = self.transpose02314(dout)
dout_shape = self.shape(dout)
normalizer = dout_shape[0]
matrix_G = self.cube_matmul(dout, dout)
normalizer = self.cast(normalizer, mstype.float32)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
self.G_normalizer = normalizer
self.matrix_G_cov = matrix_G
else:
dout = self.reduce_mean(dout, 0)
dout_shape = self.shape(dout)
dout = self.reshape(dout, (dout_shape[0], -1))
dout_shape = self.shape(dout)
normalizer = dout_shape[1]
dout = self.cast(dout, mstype.float32)
matrix_G = self.matmul(dout, dout)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
self.G_normalizer = normalizer
self.matrix_G_cov = matrix_G
return out



def construct(self, x):
if self.thor:
matrix_A = self.img2col(x)
matrix_A_shape = self.shape(matrix_A)
if self.is_Ascend:
normalizer = matrix_A_shape[0]
matrix_A = self.cube_matmul(matrix_A, matrix_A)
if self.channels_slice_flag:
matrix_A = self.reshape(matrix_A, (self.hw, self.C0, self.hw, self.C0))
matrix_A = self.slice(matrix_A, (0, 0, 0, 0),
(self.hw, self.in_channels, self.hw, self.in_channels))
matrix_A = self.reshape(matrix_A, (self.matrix_A_dim, self.matrix_A_dim))
normalizer = self.cast(normalizer, mstype.float32)
matrix_A = self.mul(matrix_A, 1.0 / normalizer)
if self.padA_flag:
matrix_A = self.padA(matrix_A)
self.A_normalizer = normalizer
self.matrix_A_cov = matrix_A
else:
matrix_A = self.reshape(matrix_A, (matrix_A_shape[0] * matrix_A_shape[1] * matrix_A_shape[2],
matrix_A_shape[3], -1))
matrix_A = self.reduce_mean(matrix_A, 1)
matrix_A_shape = self.shape(matrix_A)
normalizer = matrix_A_shape[1]
matrix_A = self.cast(matrix_A, mstype.float32)
matrix_A = self.matmul(matrix_A, matrix_A)
matrix_A = self.mul(matrix_A, 1.0 / normalizer)
self.A_normalizer = normalizer
self.matrix_A_cov = matrix_A
output = self.conv2d(x, self.weight)
output = self.getG(output)
else:
output = self.conv2d(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
return output

def extend_repr(self):
s = 'input_channels={}, output_channels={}, kernel_size={},' \
'stride={}, pad_mode={}, padding={}, dilation={}, ' \
'group={}, has_bias={},' \
'weight_init={}, bias_init={}'.format(
self.in_channels,
self.out_channels,
self.kernel_size,
self.stride,
self.pad_mode,
self.padding,
self.dilation,
self.group,
self.has_bias,
self.weight_init,
self.bias_init)
return s

class Embedding_Thor(Cell):
r"""
A simple lookup table that stores embeddings of a fixed dictionary and size.

This module is often used to store word embeddings and retrieve them using
indices. The input to the module is a list of indices, and the output is
the corresponding word embeddings.

Note:
When 'use_one_hot' is set to True, the type of the input must be mindspore.int32.

Args:
vocab_size (int): Size of the dictionary of embeddings.
embedding_size (int): The size of each embedding vector.
use_one_hot (bool): Specifies whether to apply one_hot encoding form. Default: False.
embedding_table (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the embedding_table.
Refer to class `initializer` for the values of string when a string
is specified. Default: 'normal'.
dtype (:class:`mindspore.dtype`): Data type of input. Default: mindspore.float32.
padding_idx (int, None): When the padding_idx encounters index, the output embedding vector of this index
will be initialized to zero. Default: None. The feature is inactivated.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The elements of
the Tensor must be integer and not larger than vocab_size. Otherwise the corresponding embedding vector will
be zero.

Outputs:
Tensor of shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`.

Examples:
>>> net = nn.Embedding(20000, 768, True)
>>> input_data = Tensor(np.ones([8, 128]), mindspore.int32)
>>>
>>> # Maps the input word IDs to word embedding.
>>> output = net(input_data)
>>> output.shape
(8, 128, 768)
"""

def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal',
dtype=mstype.float32, padding_idx=None):
super(Embedding_Thor, self).__init__()
self.vocab_size = Validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
self.embedding_size = Validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
Validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
self.use_one_hot = use_one_hot
self.dtype = dtype
self.init_tensor = initializer(embedding_table, [vocab_size, embedding_size])
self.padding_idx = padding_idx
if padding_idx is not None:
self.padding_idx = Validator.check_int_range(padding_idx, 0, vocab_size, Rel.INC_BOTH,
"padding_idx", self.cls_name)
self.init_tensor = self.init_tensor.to_tensor().asnumpy()
self.init_tensor[self.padding_idx] = 0
self.embedding_table = Parameter(self.init_tensor, name='embedding_table')
self.expand = P.ExpandDims()
self.reshape_flat = P.Reshape()
self.shp_flat = (-1,)
self.gather = P.GatherV2()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, self.dtype)
self.off_value = Tensor(0.0, self.dtype)
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.get_shp = P.Shape()
self.thor = True
self.matrix_A = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)),
name='matrix_A', requires_grad=False)
self.matrix_G = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)),
name="matrix_G", requires_grad=False)
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.getG = P.InsertGradientOf(self.save_gradient)
self.cast = P.Cast()
if context.get_context("device_target") == "Ascend":
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
else:
self.cube_matmul = P.MatMul(transpose_a=True)
self.mul = P.Mul()


def save_gradient(self, dout):
"""
this function only for thor optimizer
save_gradient
"""
out = dout
shape = self.get_shp(dout)
normalizer = self.cast(shape[0], mstype.float32)
matrix_G = self.cube_matmul(dout, dout)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
self.matrix_G = matrix_G
return out

def construct(self, ids):
extended_ids = self.expand(ids, -1)
out_shape = self.get_shp(ids) + (self.embedding_size,)
flat_ids = self.reshape_flat(extended_ids, self.shp_flat)

if self.use_one_hot:
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
else:
if self.thor:
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
matrix_A = self.reduce_sum(one_hot_ids, 0)
self.matrix_A = matrix_A
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
output_for_reshape = self.getG(output_for_reshape)
else:
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)


output = self.reshape(output_for_reshape, out_shape)
return output


def extend_repr(self):
s = 'vocab_size={}, embedding_size={}, use_one_hot={}, embedding_table={}, dtype={}, padding_idx={}'.format(
self.vocab_size, self.embedding_size, self.use_one_hot, self.embedding_table, self.dtype, self.padding_idx)
return s

+ 2
- 1
mindspore/nn/optim/__init__.py View File

@@ -29,6 +29,7 @@ from .rmsprop import RMSProp
from .proximal_ada_grad import ProximalAdagrad
from .lazyadam import LazyAdam
from .ada_grad import Adagrad
from .thor import THOR

__all__ = ['Optimizer', 'Momentum', 'LARS', 'Adam', 'AdamWeightDecay', 'LazyAdam', 'AdamOffload',
'Lamb', 'SGD', 'FTRL', 'RMSProp', 'ProximalAdagrad', 'Adagrad']
'Lamb', 'SGD', 'FTRL', 'RMSProp', 'ProximalAdagrad', 'Adagrad', 'THOR']

+ 897
- 0
mindspore/nn/optim/thor.py View File

@@ -0,0 +1,897 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""THOR"""
import numpy as np
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common.tensor import Tensor
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.nn.layer import Dense_Thor, Conv2d_Thor, Embedding_Thor
from mindspore.nn.wrap import DistributedGradReducer
from mindspore.train.train_thor.convert_utils import ConvertNetUntils
from mindspore.parallel._auto_parallel_context import auto_parallel_context

# Enumerates types of Layer
Other = -1
Conv = 1
FC = 2
Embedding = 3
LayerNorm = 4
BatchNorm = 5


_momentum_opt = C.MultitypeFuncGraph("momentum_opt")

op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")


@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
"""Get grad with weight_decay."""
if if_apply:
return op_add((weight * weight_decay, gradient))
return gradient


@_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment):
"""Apply momentum optimizer to the weight parameter using Tensor."""
success = True
success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
return success

C0 = 16


def caculate_device_shape(matrix_dim, channel, is_A):
ll = (0)
if is_A:
if channel // C0 == 0:
matrix_dim = (matrix_dim / channel) * C0
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
else:
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
return ll


def caculate_matmul_shape(matrix_A_dim, matrix_G_dim, split_dim):
"""get matmul shape"""
split_dimA = split_dim
split_dimG = split_dim
if matrix_A_dim % split_dim == 0:
batch_w = matrix_A_dim // split_dim
else:
if matrix_A_dim < split_dim:
batch_w = 1
split_dimA = matrix_A_dim
else:
batch_w = matrix_A_dim // split_dim + 1

if matrix_G_dim % split_dim == 0:
batch_h = matrix_G_dim // split_dim
else:
if matrix_G_dim < split_dim:
batch_h = 1
split_dimG = matrix_G_dim
else:
batch_h = matrix_G_dim // split_dim + 1
matrix_A_shape = (batch_h, batch_w, split_dimA, split_dimA)
matrix_G_shape = (batch_h, split_dimG, split_dimG)
return matrix_A_shape, matrix_G_shape


def find_net_layertype_recur(net, layertype_map):
"""get net layer type recursively."""
cells = net.name_cells()
for name in cells:
subcell = cells[name]
print("thor subcell name: ", name)
if subcell == net:
continue
elif isinstance(subcell, Conv2d_Thor):
layertype_map.append(Conv)
elif isinstance(subcell, Dense_Thor):
layertype_map.append(FC)
elif isinstance(subcell, Embedding_Thor):
layertype_map.append(Embedding)
elif isinstance(subcell, nn.LayerNorm):
layertype_map.append(LayerNorm)
elif isinstance(subcell, nn.BatchNorm2d):
layertype_map.append(BatchNorm)
elif isinstance(subcell, (nn.Conv2d, nn.Dense, nn.Embedding, nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose,
nn.BatchNorm1d, nn.GroupNorm, nn.GlobalBatchNorm)):
layertype_map.append(Other)
else:
find_net_layertype_recur(subcell, layertype_map)

def get_net_layertype_mask(net):
layertype_map = []
find_net_layertype_recur(net, layertype_map)
return layertype_map

def get_layer_counter(layer_type, layer_counter, params, idx):
"""get layer counter"""
if layer_type in [Conv, FC, LayerNorm, BatchNorm]:
if layer_type in [LayerNorm, BatchNorm]:
if "beta" in params[idx].name.lower():
layer_counter = layer_counter + 1
else:
if "bias" in params[idx].name.lower():
layer_counter = layer_counter + 1
else:
if idx < len(params) - 1 and "bias" not in params[idx + 1].name.lower():
layer_counter = layer_counter + 1
else:
layer_counter = layer_counter + 1
return layer_counter


def THOR(net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None):
context.set_context(max_call_depth=10000)
ConvertNetUntils().convert_to_thor_net(net)
if context.get_context("device_target") == "Ascend":
return THOR_Ascend(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size, decay_filter,
split_indices=split_indices)
return THOR_GPU(net, learning_rate, damping, momentum, weight_decay, loss_scale, batch_size,
use_nesterov, decay_filter, split_indices=split_indices)


class THOR_GPU(Optimizer):
"""
THOR_GPU
"""
def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
use_nesterov=False, decay_filter=lambda x: x.name not in [], split_indices=None):
params = filter(lambda x: x.requires_grad, net.get_parameters())
super(THOR_GPU, self).__init__(learning_rate, params, weight_decay, loss_scale)
Validator.check_value_type("momentum", momentum, [float], self.cls_name)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters
self.use_nesterov = Validator.check_bool(use_nesterov)
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum(use_nesterov=self.use_nesterov)
self.net = net
self.matrix_A_cov = ParameterTuple(filter(lambda x: 'matrix_A' in x.name, net.get_parameters()))
self.matrix_G_cov = ParameterTuple(filter(lambda x: 'matrix_G' in x.name, net.get_parameters()))
self.A_normalizer = ParameterTuple(filter(lambda x: 'A_normalizer' in x.name, net.get_parameters()))
self.G_normalizer = ParameterTuple(filter(lambda x: 'G_normalizer' in x.name, net.get_parameters()))
self.transpose = P.Transpose()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.matmul = P.MatMul()
self.assign = P.Assign()
self.mul = P.Mul()
self.damping = damping
self.gather = P.GatherV2()
self.one = Tensor(1, mstype.int32)
self.batch_size = Tensor(batch_size, mstype.float32)
self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
self.feature_map = Tensor(1.0, mstype.float32)
self.axis = 0
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
self.cast = P.Cast()
self.sqrt = P.Sqrt()
self.eye = P.Eye()
split_dim = 128
self.embedding_cholesky = P.CholeskyTrsm()
self.cholesky = P.CholeskyTrsm(split_dim=split_dim)
self.vector_matmul = P.BatchMatMul(transpose_a=True)
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.inv = P.Reciprocal()
self.square = P.Square()
self.expand = P.ExpandDims()
self.thor = True

self.matrix_A = ()
self.matrix_G = ()
self.matrix_A_shape = ()
self.thor_layer_count = 0
self.conv_layer_count = 0
self.weight_fim_idx_map = ()
self.weight_conv_idx_map = ()
self.weight_layerType_idx_map = ()
layer_type_map = get_net_layertype_mask(net)

layer_counter = 0
for idx in range(len(self.params)):
layer_type = layer_type_map[layer_counter]
weight = self.params[idx]
weight_shape = self.shape(weight)
if layer_type in [Conv, FC] and "bias" not in self.params[idx].name.lower():
in_channels = weight_shape[1]
out_channels = weight_shape[0]
matrix_A_dim = in_channels
if layer_type == Conv:
matrix_A_dim = in_channels * weight_shape[2] * weight_shape[3]
matrix_G_dim = out_channels
matrix_A_shape, matrix_G_shape = caculate_matmul_shape(matrix_A_dim, matrix_G_dim, split_dim)
matrix_A_inv = Parameter(np.zeros(matrix_A_shape).astype(np.float32),
name='matrix_A_inv_' + str(self.thor_layer_count), requires_grad=False)
matrix_G_inv = Parameter(np.zeros(matrix_G_shape).astype(np.float32),
name="matrix_G_inv_" + str(self.thor_layer_count), requires_grad=False)
self.matrix_A = self.matrix_A + (matrix_A_inv,)
self.matrix_G = self.matrix_G + (matrix_G_inv,)
self.matrix_A_shape = self.matrix_A_shape + (matrix_A_shape,)
elif layer_type == Embedding:
vocab_size = weight_shape[0]
embedding_size = weight_shape[1]
matrix_A_inv = Parameter(Tensor(np.zeros([vocab_size]).astype(np.float32)),
name='matrix_A_inv_' + str(self.thor_layer_count), requires_grad=False)
matrix_G_inv = Parameter(Tensor(np.zeros([embedding_size, embedding_size]).astype(np.float32)),
name="matrix_G_inv_" + str(self.thor_layer_count), requires_grad=False)
self.matrix_A = self.matrix_A + (matrix_A_inv,)
self.matrix_G = self.matrix_G + (matrix_G_inv,)
self.matrix_A_shape = self.matrix_A_shape + ((vocab_size,),)

if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
self.weight_layerType_idx_map = self.weight_layerType_idx_map + (layer_type,)
self.thor_layer_count = self.thor_layer_count + 1
if layer_type == Conv:
self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
self.conv_layer_count = self.conv_layer_count + 1
else:
self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
else:
self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
if layer_type == LayerNorm:
self.weight_layerType_idx_map = self.weight_layerType_idx_map + (LayerNorm,)
else:
self.weight_layerType_idx_map = self.weight_layerType_idx_map + (Other,)
# bert.cls1.output_bias: not a network layer, only a trainable param
if "output_bias" not in self.params[idx].name.lower():
layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)

self.matrix_A = ParameterTuple(self.matrix_A)
self.matrix_G = ParameterTuple(self.matrix_G)
self.weight_decay = weight_decay
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
self.update_gradient = P.UpdateThorGradient(split_dim=split_dim)

self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
if self.is_distributed:
mean = _get_gradients_mean()
degree = _get_device_num()
if self.conv_layer_count > 0:
if not split_indices:
self.split_indices = split_indices
else:
self.split_indices = [len(self.matrix_A) - 1]
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2")
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4")
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
self.grad_reducer_Amax = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=2)
self.grad_reducer_Gmax = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=4)
self.grad_reducer_A = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=6)
self.grad_reducer_G = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=8)
else:
if not split_indices:
self.split_indices = split_indices
else:
self.split_indices = [len(self.params) - 1]
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum3")
self.grad_reducer_g = DistributedGradReducer(self.params, mean, degree, fusion_type=3)

def _get_Ainv_Ginv_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce):
"""get matrixA inverse list and matrix G inverse list"""
for i in range(len(self.params)):
thor_layer_count = self.weight_fim_idx_map[i]
conv_layer_count = self.weight_conv_idx_map[i]
layer_type = self.weight_layerType_idx_map[i]
if layer_type in [Conv, FC, Embedding]:
g = gradients[i]
matrix_A = self.matrix_A_cov[thor_layer_count]
matrix_G = self.matrix_G_cov[thor_layer_count]
matrix_A = F.depend(matrix_A, g)
matrix_G = F.depend(matrix_G, g)
dampingA = damping_step
dampingG = damping_step
feature_map = self.feature_map
if layer_type == Conv:
A_normalizer = self.A_normalizer[conv_layer_count]
G_normalizer = self.G_normalizer[conv_layer_count]
A_normalizer = F.depend(A_normalizer, g)
G_normalizer = F.depend(G_normalizer, g)
dampingA = self.mul(damping_step, 1.0 / A_normalizer)
dampingG = self.mul(damping_step, 1.0 / G_normalizer)
feature_map = self.sqrt(1.0 / A_normalizer)
A_shape = self.shape(matrix_A)
A_eye = self.eye(A_shape[0], A_shape[0], mstype.float32)
dampingA = self.sqrt(dampingA)
dampingG = self.sqrt(dampingG)
G_shape = self.shape(matrix_G)
G_eye = self.eye(G_shape[0], G_shape[1], mstype.float32)
matrix_G = self.mul(matrix_G, self.loss_scale)
matrix_G = self.mul(matrix_G, self.batch_size_scale)
matrix_G = matrix_G + dampingG * G_eye
if layer_type == Embedding:
A_eye = P.OnesLike()(matrix_A)
matrix_A = self.mul(matrix_A, 1.0 / self.batch_size)
matrix_A = matrix_A + dampingA * A_eye
matrix_A = self.inv(matrix_A)
matrix_G = self.embedding_cholesky(matrix_G)
matrix_G = self.matmul(matrix_G, matrix_G)
else:
matrix_A = matrix_A + dampingA * A_eye
matrix_A = self.cholesky(matrix_A)
matrix_A = self.vector_matmul(matrix_A, matrix_A)
matrix_A = P.BroadcastTo(self.matrix_A_shape[thor_layer_count])(matrix_A)
matrix_G = self.cholesky(matrix_G)
matrix_G = self.vector_matmul(matrix_G, matrix_G)
matrix_A = self.mul(matrix_A, feature_map)
matrix_G = self.mul(matrix_G, feature_map)
matrix_a_allreduce = matrix_a_allreduce + (matrix_A,)
matrix_g_allreduce = matrix_g_allreduce + (matrix_G,)
return matrix_a_allreduce, matrix_g_allreduce

def construct(self, gradients):
params = self.params
moments = self.moments
gradients = self.scale_grad(gradients)
damping_step = self.gather(self.damping, self.cov_step, self.axis)
damping_step = self.cast(damping_step, mstype.float32)
new_grads = ()
if self.thor:
matrix_Ainv_list = ()
matrix_Ginv_list = ()
matrix_A_allreduce, matrix_G_allreduce = self._get_Ainv_Ginv_list(gradients, damping_step,
matrix_Ainv_list, matrix_Ginv_list)
if self.is_distributed and self.conv_layer_count > 0:
matrix_A_allreduce = self.grad_reducer_A(matrix_A_allreduce)
matrix_G_allreduce = self.grad_reducer_G(matrix_G_allreduce)

for i in range(len(self.params)):
g = gradients[i]
thor_layer_count = self.weight_fim_idx_map[i]
conv_layer_count = self.weight_conv_idx_map[i]
layer_type = self.weight_layerType_idx_map[i]
if layer_type in [Conv, FC]:
g_shape = self.shape(g)
g = self.reshape(g, (g_shape[0], -1))
matrix_A = matrix_A_allreduce[thor_layer_count]
matrix_G = matrix_G_allreduce[thor_layer_count]
g = self.update_gradient(matrix_G, g, matrix_A)
fake_A = self.assign(self.matrix_A[thor_layer_count], matrix_A)
fake_G = self.assign(self.matrix_G[thor_layer_count], matrix_G)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
if conv_layer_count != -1:
g = self.reshape(g, g_shape)
elif layer_type == Embedding:
matrix_A = matrix_A_allreduce[thor_layer_count]
matrix_G = matrix_G_allreduce[thor_layer_count]
fake_A = self.assign(self.matrix_A[thor_layer_count], matrix_A)
fake_G = self.assign(self.matrix_G[thor_layer_count], matrix_G)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
temp_a = self.expand(matrix_A, 1)
g = self.mul(temp_a, g)
g = self.matmul(g, matrix_G)
elif layer_type == LayerNorm:
damping = self.sqrt(damping_step)
normalizer = self.batch_size
normalizer = self.cast(normalizer, mstype.float32)
fim_cov = self.square(g)
fim_cov = self.mul(fim_cov, 1.0 / normalizer)
fim_cov = fim_cov + damping
fim_inv = self.inv(fim_cov)
g = self.mul(fim_inv, g)
new_grads = new_grads + (g,)
else:
for j in range(len(self.params)):
g = gradients[j]
thor_layer_count = self.weight_fim_idx_map[j]
conv_layer_count = self.weight_conv_idx_map[j]
layer_type = self.weight_layerType_idx_map[j]
if layer_type in [Conv, FC]:
g_shape = self.shape(g)
g = self.reshape(g, (g_shape[0], -1))
matrix_A = self.matrix_A[thor_layer_count]
matrix_G = self.matrix_G[thor_layer_count]
g = self.update_gradient(matrix_G, g, matrix_A)
if conv_layer_count != -1:
g = self.reshape(g, g_shape)
elif layer_type == Embedding:
matrix_A = self.matrix_A[thor_layer_count]
matrix_G = self.matrix_G[thor_layer_count]
g = gradients[j]
temp_a = self.expand(matrix_A, 1)
g = self.mul(temp_a, g)
g = self.matmul(g, matrix_G)
elif layer_type == LayerNorm:
damping = self.sqrt(damping_step)
normalizer = self.batch_size
normalizer = self.cast(normalizer, mstype.float32)
fim_cov = self.square(g)
fim_cov = self.mul(fim_cov, 1.0 / normalizer)
fim_cov = fim_cov + damping
fim_inv = self.inv(fim_cov)
g = self.mul(fim_inv, g)
new_grads = new_grads + (g,)
gradients = new_grads

if self.is_distributed and self.conv_layer_count == 0:
gradients = self.grad_reducer_g(gradients)
self.cov_step = self.cov_step + self.one
if self.weight_decay > 0:
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
lr = self.get_lr()
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
return success


class THOR_Ascend(Optimizer):
"""THOR"""

def __init__(self, net, learning_rate, damping, momentum, weight_decay=0.0, loss_scale=1.0, batch_size=32,
decay_filter=lambda x: x.name not in [], split_indices=None):
params = filter(lambda x: x.requires_grad, net.get_parameters())
super(THOR_Ascend, self).__init__(learning_rate, params, weight_decay, loss_scale)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum()
self.net = net
self.matrix_A_cov = ParameterTuple(filter(lambda x: 'matrix_A' in x.name, net.get_parameters()))
self.matrix_G_cov = ParameterTuple(filter(lambda x: 'matrix_G' in x.name, net.get_parameters()))
self.A_normalizer = ParameterTuple(filter(lambda x: 'A_normalizer' in x.name, net.get_parameters()))
self.G_normalizer = ParameterTuple(filter(lambda x: 'G_normalizer' in x.name, net.get_parameters()))
self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()
self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
self.transpose = P.Transpose()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.mul = P.Mul()

self.C0 = 16
self.matrix_A_dim = ()
self.padA_flag = ()
self.device_shape_pad_flag = ()
self.diag_block_dim = 128
self.matrix_A = ()
self.matrix_G = ()
print("matrix_A_cov len is", len(self.matrix_A_cov))
self.thor_layer_count = 0
self.conv_layer_count = 0
self.weight_fim_idx_map = ()
self.weight_conv_idx_map = ()
self.weight_layerType_idx_map = ()
layer_type_map = get_net_layertype_mask(net)
layer_counter = 0
for idx in range(len(self.params)):
layer_type = layer_type_map[layer_counter]
weight = self.params[idx]
weight_shape = self.shape(weight)
if layer_type == Conv and "bias" not in self.params[idx].name.lower():
in_channels = weight_shape[1]
out_channels = weight_shape[0]
matrix_A_dim = in_channels * weight_shape[2] * weight_shape[3]
matrix_G_dim = out_channels
matrix_A_device_shape, matrix_A_device_dim = caculate_device_shape(matrix_A_dim, in_channels, True)
matrix_G_device_shape, matrix_G_device_dim = caculate_device_shape(matrix_G_dim, in_channels, False)
matrix_A_inv = Parameter(
Tensor(np.reshape(np.identity(matrix_A_device_dim).astype(np.float16), matrix_A_device_shape)),
name='matrix_A_inv_' + str(self.thor_layer_count), requires_grad=False)
matrix_G_inv = Parameter(
Tensor(np.reshape(np.identity(matrix_G_device_dim).astype(np.float16), matrix_G_device_shape)),
name="matrix_G_inv_" + str(self.thor_layer_count), requires_grad=False)
self.matrix_A = self.matrix_A + (matrix_A_inv,)
self.matrix_G = self.matrix_G + (matrix_G_inv,)
self.matrix_A_dim = self.matrix_A_dim + (matrix_A_dim,)
padA_flag = False
if (matrix_A_dim // self.diag_block_dim) * self.diag_block_dim != matrix_A_dim \
and matrix_A_dim > self.diag_block_dim:
padA_flag = True
self.padA_flag = self.padA_flag + (padA_flag,)
device_shape_pad_flag = False
if matrix_A_dim != matrix_A_device_dim:
device_shape_pad_flag = True
self.device_shape_pad_flag = self.device_shape_pad_flag + (device_shape_pad_flag,)
elif layer_type == FC and "bias" not in self.params[idx].name.lower():
out_channels = weight_shape[0]
if out_channels == 1001:
fc_matrix_A = Parameter(Tensor(np.zeros([128, 128, 16, 16]).astype(np.float16)),
name='matrix_A_inv_' + str(self.thor_layer_count),
requires_grad=False)
fc_matrix_G = Parameter(Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)),
name="matrix_G_inv_" + str(self.thor_layer_count),
requires_grad=False)
self.matrix_A = self.matrix_A + (fc_matrix_A,)
self.matrix_G = self.matrix_G + (fc_matrix_G,)

if layer_type in [Conv, FC, Embedding] and "bias" not in self.params[idx].name.lower():
self.weight_fim_idx_map = self.weight_fim_idx_map + (self.thor_layer_count,)
self.weight_layerType_idx_map = self.weight_layerType_idx_map + (layer_type,)
self.thor_layer_count = self.thor_layer_count + 1
if layer_type == Conv:
self.weight_conv_idx_map = self.weight_conv_idx_map + (self.conv_layer_count,)
self.conv_layer_count = self.conv_layer_count + 1
else:
self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
else:
self.weight_fim_idx_map = self.weight_fim_idx_map + (-1,)
self.weight_conv_idx_map = self.weight_conv_idx_map + (-1,)
if layer_type == LayerNorm:
self.weight_layerType_idx_map = self.weight_layerType_idx_map + (LayerNorm,)
else:
self.weight_layerType_idx_map = self.weight_layerType_idx_map + (Other,)
# bert.cls1.output_bias: not a network layer, only a trainable param
if "output_bias" not in self.params[idx].name.lower():
layer_counter = get_layer_counter(layer_type, layer_counter, self.params, idx)

self.matrix_A = ParameterTuple(self.matrix_A)
self.matrix_G = ParameterTuple(self.matrix_G)
self.matrix_max_inv = ()
for i in range(len(self.matrix_A)):
self.matrix_max_inv = self.matrix_max_inv + (
Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
self.log = P.Log()
self.exp = P.Exp()
self.sqrt = P.Sqrt()
self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
self.assign = P.Assign()
self.cast = P.Cast()
self.thor = True
self.weight_decay = weight_decay * loss_scale
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
self.damping = damping
self.gather = P.GatherV2()
self.one = Tensor(1, mstype.int32)
self.batch_size = Tensor(batch_size, mstype.float32)
self.loss_scale = Tensor(1 / (loss_scale * loss_scale), mstype.float32)
self.batch_size_scale = Tensor(batch_size * batch_size, mstype.float32)
self.axis = 0
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
self.cast = P.Cast()
self.eye = P.Eye()
self.cholesky = P.CusCholeskyTrsm()
self.vector_matmul = P.CusBatchMatMul()
self.fused_abs_max2 = P.CusFusedAbsMax1()
self.matrix_combine = P.CusMatrixCombine()
self.slice = P.Slice()
self.expand = P.ExpandDims()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.square = P.Square()
self.inv = P.Inv()
self.matmul = P.MatMul()

self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
if self.is_distributed:
mean = _get_gradients_mean()
degree = _get_device_num()
if self.conv_layer_count > 0:
if not split_indices:
self.split_indices = split_indices
else:
self.split_indices = [len(self.matrix_A) - 1]
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum2")
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum4")
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum6")
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum8")
self.grad_reducer_Amax = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=2)
self.grad_reducer_Gmax = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=4)
self.grad_reducer_A = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=6)
self.grad_reducer_G = DistributedGradReducer(self.matrix_A, mean, degree, fusion_type=8)
else:
if not split_indices:
self.split_indices = split_indices
else:
self.split_indices = [len(self.params) - 1]
auto_parallel_context().set_all_reduce_fusion_split_indices(self.split_indices, "hccl_world_groupsum3")
self.grad_reducer_g = DistributedGradReducer(self.params, mean, degree, fusion_type=3)

def _get_Ainv_Ginv_Amax_Gmax_list(self, gradients, damping_step, matrix_a_allreduce, matrix_g_allreduce,
matrix_a_max_allreduce, matrix_g_max_allreduce):
"""get matrixA inverse list, matrixG inverse list, matrixA_max list, matrixG_max list"""
for i in range(len(self.params)):
thor_layer_count = self.weight_fim_idx_map[i]
conv_layer_count = self.weight_conv_idx_map[i]
layer_type = self.weight_layerType_idx_map[i]
if layer_type in [Conv, FC, Embedding]:
g = gradients[i]
matrix_A = self.matrix_A_cov[thor_layer_count]
matrix_G = self.matrix_G_cov[thor_layer_count]
matrix_A = F.depend(matrix_A, g)
matrix_G = F.depend(matrix_G, g)
A_shape = self.shape(matrix_A)
A_eye = self.eye(A_shape[0], A_shape[0], mstype.float32)
G_shape = self.shape(matrix_G)
G_eye = self.eye(G_shape[0], G_shape[0], mstype.float32)
if layer_type == Conv:
A_normalizer = self.A_normalizer[conv_layer_count]
G_normalizer = self.G_normalizer[conv_layer_count]
A_normalizer = F.depend(A_normalizer, g)
G_normalizer = F.depend(G_normalizer, g)
dampingA = self.mul(damping_step, self.batch_size / A_normalizer)
dampingG = self.mul(damping_step, self.batch_size / G_normalizer)
dampingA = self.sqrt(dampingA)
matrix_A = matrix_A + dampingA * A_eye
matrix_A_inv = self.cholesky(matrix_A)
matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv)
A_max = P.CusFusedAbsMax1([self.matrix_A_dim[conv_layer_count],
self.matrix_A_dim[conv_layer_count]])(matrix_A_inv)
A_max = self.fused_abs_max2(A_max)
matrix_A_inv = self.matrix_combine(matrix_A_inv)
if self.padA_flag[conv_layer_count]:
matrix_A_inv = self.slice(matrix_A_inv, (0, 0), (self.matrix_A_dim[conv_layer_count],
self.matrix_A_dim[conv_layer_count]))
if self.device_shape_pad_flag[conv_layer_count]:
weight = self.params[i]
weight_shape = self.shape(weight)
kernel_hw = weight_shape[2] * weight_shape[3]
in_channels = weight_shape[1]
matrix_A_inv = self.reshape(matrix_A_inv, (kernel_hw, in_channels, kernel_hw, in_channels))
matrix_A_inv = P.Pad(((0, 0), (0, self.C0 - in_channels), (0, 0),
(0, self.C0 - in_channels)))(matrix_A_inv)
matrix_A_inv_shape = self.shape(self.matrix_A[thor_layer_count])
matrix_A_device_temp_shape = (matrix_A_inv_shape[0], matrix_A_inv_shape[2],
matrix_A_inv_shape[1], matrix_A_inv_shape[3])
matrix_A_inv = self.reshape(matrix_A_inv, matrix_A_device_temp_shape)
matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3))

dampingG = self.sqrt(dampingG)
matrix_G = self.mul(matrix_G, self.loss_scale)
matrix_G = self.mul(matrix_G, self.batch_size_scale)
matrix_G = matrix_G + dampingG * G_eye
matrix_G_inv = self.cholesky(matrix_G)
matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
G_max = self.fused_abs_max2(matrix_G_inv)
G_max = self.fused_abs_max2(G_max)
matrix_G_inv = self.matrix_combine(matrix_G_inv)
matrix_G_inv_shape = self.shape(self.matrix_G[thor_layer_count])
matrix_G_device_temp_shape = (matrix_G_inv_shape[0], matrix_G_inv_shape[2],
matrix_G_inv_shape[1], matrix_G_inv_shape[3])
matrix_G_inv = self.reshape(matrix_G_inv, matrix_G_device_temp_shape)
matrix_G_inv = self.transpose(matrix_G_inv, (2, 0, 1, 3))

A_max = F.depend(A_max, g)
G_max = F.depend(G_max, g)
matrix_a_allreduce = matrix_a_allreduce + (matrix_A_inv,)
matrix_g_allreduce = matrix_g_allreduce + (matrix_G_inv,)
matrix_a_max_allreduce = matrix_a_max_allreduce + (A_max,)
matrix_g_max_allreduce = matrix_g_max_allreduce + (G_max,)
elif layer_type == FC:
damping = self.sqrt(damping_step)
matrix_A = matrix_A + damping * A_eye
matrix_A_inv = self.cholesky(matrix_A)
matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv)
weight_shape = self.shape(self.params[i])
out_channels = weight_shape[0]
if out_channels == 2:
matrix_A_inv = self.matrix_combine(matrix_A_inv)
matrix_G_inv = G_eye
else:
matrix_G = self.mul(matrix_G, self.loss_scale)
matrix_G = self.mul(matrix_G, self.batch_size_scale)
matrix_G = matrix_G + damping * G_eye
matrix_G_inv = self.cholesky(matrix_G)
matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
if out_channels == 1001:
matrix_A_inv_max = self.fused_abs_max2(matrix_A_inv)
A_max = self.fused_abs_max2(matrix_A_inv_max)
matrix_A_inv = self.matrix_combine(matrix_A_inv)
matrix_A_inv_shape = self.shape(matrix_A_inv)
matrix_A_inv = self.reshape(matrix_A_inv,
(matrix_A_inv_shape[0] / 16, 16,
matrix_A_inv_shape[0] / 16, 16))
matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3))
matrix_G_inv_max = P.CusFusedAbsMax1([1001, 1001])(matrix_G_inv)
G_max = self.fused_abs_max2(matrix_G_inv_max)
matrix_G_inv = self.matrix_combine(matrix_G_inv)
matrix_G_inv = self.slice(matrix_G_inv, (0, 0), (1001, 1001))
matrix_G_inv = P.Pad(((0, 7), (0, 7)))(matrix_G_inv)
matrix_G_inv_shape = self.shape(matrix_G_inv)
matrix_G_inv = self.reshape(matrix_G_inv,
(matrix_G_inv_shape[0] / 16, 16,
matrix_G_inv_shape[0] / 16, 16))
matrix_G_inv = self.transpose(matrix_G_inv, (2, 0, 1, 3))
A_max = F.depend(A_max, g)
G_max = F.depend(G_max, g)
matrix_a_max_allreduce = matrix_a_max_allreduce + (A_max,)
matrix_g_max_allreduce = matrix_g_max_allreduce + (G_max,)
else:
matrix_A_inv = self.matrix_combine(matrix_A_inv)
matrix_G_inv = self.matrix_combine(matrix_G_inv)
matrix_a_allreduce = matrix_a_allreduce + (matrix_A_inv,)
matrix_g_allreduce = matrix_g_allreduce + (matrix_G_inv,)
elif layer_type == Embedding:
damping = self.sqrt(damping_step)
A_eye = P.OnesLike()(matrix_A)
matrix_A = self.mul(matrix_A, 1.0 / self.batch_size)
matrix_A = matrix_A + damping * A_eye
matrix_A_inv = self.inv(matrix_A)
matrix_G = self.mul(matrix_G, self.loss_scale)
matrix_G = self.mul(matrix_G, self.batch_size_scale)
matrix_G = matrix_G + damping * G_eye
matrix_G_inv = self.cholesky(matrix_G)
matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
matrix_G_inv = self.matrix_combine(matrix_G_inv)
matrix_a_allreduce = matrix_a_allreduce + (matrix_A_inv,)
matrix_g_allreduce = matrix_g_allreduce + (matrix_G_inv,)
return matrix_a_allreduce, matrix_g_allreduce, matrix_a_max_allreduce, matrix_g_max_allreduce

def _process_layernorm(self, damping_step, gradient):
"""process layernorm layer for thor"""
damping = self.sqrt(damping_step)
normalizer = self.cast(self.batch_size, mstype.float32)
fim_cov = self.square(gradient)
fim_cov = self.mul(fim_cov, 1.0 / normalizer)
fim_cov = fim_cov + damping
fim_inv = self.inv(fim_cov)
gradient = self.mul(fim_inv, gradient)
return gradient

def _get_second_gradients(self, new_grads, damping_step, gradients):
"""get second gradients for thor"""
params_len = len(self.params)
for i in range(params_len):
g = gradients[i]
thor_layer_count = self.weight_fim_idx_map[i]
layer_type = self.weight_layerType_idx_map[i]
if self.conv_layer_count > 0:
matrix_A = self.matrix_A[thor_layer_count]
matrix_G = self.matrix_G[thor_layer_count]
matrix_max = self.matrix_max_inv[thor_layer_count]
if layer_type == FC:
g = self.cube_matmul_left_fc(matrix_G, g)
g = self.cube_matmul_right_fc(g, matrix_A, matrix_max)
elif layer_type == Conv:
g = self.cube_matmul_left(matrix_G, g)
g = self.cube_matmul_right_mul(g, matrix_A, matrix_max)
else:
if layer_type == Embedding:
temp_a_ori = self.matrix_A_cov[thor_layer_count]
temp_g = self.matrix_G_cov[thor_layer_count]
temp_a = self.expand(temp_a_ori, 1)
g = self.mul(temp_a, g)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(g, temp_g)
g = self.cast(g, mstype.float32)
elif layer_type == FC:
temp_a = self.matrix_A_cov[thor_layer_count]
temp_g = self.matrix_G_cov[thor_layer_count]
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
elif layer_type == LayerNorm:
g = self._process_layernorm(damping_step, g)
new_grads = new_grads + (g,)
return new_grads

def construct(self, gradients):
params = self.params
moments = self.moments
damping_step = self.gather(self.damping, self.cov_step, self.axis)
damping_step = self.cast(damping_step, mstype.float32)
if self.thor:
matrix_A_allreduce = ()
matrix_G_allreduce = ()
matrix_A_max_allreduce = ()
matrix_G_max_allreduce = ()
matrix_A_allreduce, matrix_G_allreduce, matrix_A_max_allreduce, matrix_G_max_allreduce = \
self._get_Ainv_Ginv_Amax_Gmax_list(gradients, damping_step, matrix_A_allreduce, matrix_G_allreduce,
matrix_A_max_allreduce, matrix_G_max_allreduce)
if self.is_distributed and self.conv_layer_count > 0:
matrix_A_allreduce = self.grad_reducer_A(matrix_A_allreduce)
matrix_G_allreduce = self.grad_reducer_G(matrix_G_allreduce)
matrix_A_max_allreduce = self.grad_reducer_Amax(matrix_A_max_allreduce)
matrix_G_max_allreduce = self.grad_reducer_Gmax(matrix_G_max_allreduce)

new_grads = ()
for i in range(len(self.params)):
g = gradients[i]
thor_layer_count = self.weight_fim_idx_map[i]
conv_layer_count = self.weight_conv_idx_map[i]
layer_type = self.weight_layerType_idx_map[i]
if self.conv_layer_count > 0:
temp_a = matrix_A_allreduce[thor_layer_count]
temp_g = matrix_G_allreduce[thor_layer_count]
matrix_A_inv_max = self.log(matrix_A_max_allreduce[thor_layer_count])
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
matrix_A_inv_max = self.exp(matrix_A_inv_max)
temp_a = self.mul(temp_a, matrix_A_inv_max)
matrix_G_inv_max = self.log(matrix_G_max_allreduce[thor_layer_count])
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
matrix_G_inv_max = self.exp(matrix_G_inv_max)
temp_g = self.mul(temp_g, matrix_G_inv_max)
temp_max = self.mul(matrix_A_max_allreduce[thor_layer_count],
matrix_G_max_allreduce[thor_layer_count])
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
if layer_type == FC:
g = self.cube_matmul_left_fc(temp_g, g)
g = self.cube_matmul_right_fc(g, temp_a, temp_max)
elif layer_type == Conv:
A_normalizer = self.A_normalizer[conv_layer_count]
A_normalizer = F.depend(A_normalizer, g)
temp_max = self.mul(temp_max, self.batch_size / A_normalizer)
g = self.cube_matmul_left(temp_g, g)
g = self.cube_matmul_right_mul(g, temp_a, temp_max)
fake_A = self.assign(self.matrix_A[thor_layer_count], temp_a)
fake_G = self.assign(self.matrix_G[thor_layer_count], temp_g)
fake_max = self.assign(self.matrix_max_inv[thor_layer_count], temp_max)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
g = F.depend(g, fake_max)
else:
if layer_type == Embedding:
temp_a_ori = matrix_A_allreduce[thor_layer_count]
temp_g = matrix_G_allreduce[thor_layer_count]
fake_A = self.assign(self.matrix_A_cov[thor_layer_count], temp_a_ori)
fake_G = self.assign(self.matrix_G_cov[thor_layer_count], temp_g)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
temp_a = self.expand(temp_a_ori, 1)
g = self.mul(temp_a, g)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(g, temp_g)
g = self.cast(g, mstype.float32)
elif layer_type == FC:
temp_a = matrix_A_allreduce[thor_layer_count]
temp_g = matrix_G_allreduce[thor_layer_count]
fake_A = self.assign(self.matrix_A_cov[thor_layer_count], temp_a)
fake_G = self.assign(self.matrix_G_cov[thor_layer_count], temp_g)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
g = self.cast(g, mstype.float16)
g = self.matmul(temp_g, g)
g = self.matmul(g, temp_a)
g = self.cast(g, mstype.float32)
elif layer_type == LayerNorm:
g = self._process_layernorm(damping_step, g)
new_grads = new_grads + (g,)
gradients = new_grads
else:
new_grads = ()
gradients = self._get_second_gradients(new_grads, damping_step, gradients)

if self.is_distributed and self.conv_layer_count == 0:
gradients = self.grad_reducer_g(gradients)
self.cov_step = self.cov_step + self.one
if self.weight_decay > 0:
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags, params, gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum, lr), gradients, params, moments)
return success

+ 3
- 2
mindspore/nn/wrap/grad_reducer.py View File

@@ -239,6 +239,7 @@ class DistributedGradReducer(Cell):
parameters (list): the parameters to be updated.
mean (bool): When mean is true, the mean coefficient (degree) would apply on gradients. Default: False.
degree (int): The mean coefficient. Usually it equals to device number. Default: None.
fusion_type (int): The type of all reduce fusion. Default: 1.

Raises:
ValueError: If degree is not a int or less than 0.
@@ -319,7 +320,7 @@ class DistributedGradReducer(Cell):
256.0
"""

def __init__(self, parameters, mean=True, degree=None):
def __init__(self, parameters, mean=True, degree=None, fusion_type=1):
super(DistributedGradReducer, self).__init__(auto_prefix=False)
self.map_ = C.Map()
if degree is None:
@@ -337,7 +338,7 @@ class DistributedGradReducer(Cell):
self.op_list = _init_allreduce_operators(len(parameters), split_indices)
else:
self.split_fusion = False
self.allreduce = AllReduce().add_prim_attr('fusion', 1)
self.allreduce = AllReduce().add_prim_attr('fusion', fusion_type)
self.allgather = AllGather(GlobalComm.WORLD_COMM_GROUP)
ps_filter = lambda x: x.is_param_ps
self.ps_parameters = tuple(ps_filter(x) for x in parameters)


+ 19
- 0
mindspore/train/train_thor/__init__.py View File

@@ -0,0 +1,19 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""convert to second order related classes and functions."""

from .convert_utils import ConvertNetUntils, ConvertModelUtils

__all__ = ["ConvertNetUntils", "ConvertModelUtils"]

+ 157
- 0
mindspore/train/train_thor/convert_utils.py View File

@@ -0,0 +1,157 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
convert utils for second order optimizer: thor
"""
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import context


class ConvertNetUntils():
"""
Convert net to thor layer net
"""
def __init__(self):
self._convert_method_map = {nn.Dense: self._convert_dense,
nn.Embedding: self._convert_embedding,
nn.Conv2d: self._convert_conv2d}


def _convert_dense(self, subcell):
"""
convert dense cell to second_order cell
"""

weight = subcell.weight
act_name = None
if subcell.activation_flag:
act_class = subcell.activation.__class__.__name__
act_name = act_class.lower()
if subcell.out_channels == 1001:
new_subcell = nn.Dense_Thor(in_channels=subcell.in_channels,
out_channels=subcell.out_channels,
weight_init=weight,
has_bias=subcell.has_bias,
bias_init='zeros',
activation=act_name)
else:
compute_type = mstype.float16
if context.get_context("device_target") == "GPU":
compute_type = mstype.float32
new_subcell = nn.Dense_Thor(in_channels=subcell.in_channels,
out_channels=subcell.out_channels,
weight_init=weight,
has_bias=subcell.has_bias,
bias_init='zeros',
activation=act_name).to_float(compute_type)

if subcell.has_bias:
new_subcell.bias = subcell.bias
return new_subcell


def _convert_embedding(self, subcell):
"""
convert embedding cell to second_order cell
"""
new_subcell = nn.Embedding_Thor(vocab_size=subcell.vocab_size,
embedding_size=subcell.embedding_size,
use_one_hot=False)
new_subcell.embedding_table = subcell.embedding_table
return new_subcell


def _convert_conv2d(self, subcell):
"""
convert conv2d cell to second_order cell
"""
out_channel = subcell.out_channels
in_channel = subcell.in_channels
kernel_size = subcell.kernel_size[0]
stride = subcell.stride
padding = subcell.padding
pad_mode = subcell.pad_mode
has_bias = subcell.has_bias
weight = subcell.weight
new_subcell = nn.Conv2d_Thor(in_channel, out_channel,
kernel_size=kernel_size, stride=stride, padding=padding, pad_mode=pad_mode,
has_bias=has_bias, weight_init=weight)
return new_subcell


def _convert_to_thor_net(self, net):
"""
convert net to thor net
"""
cells = net.name_cells()
change = False
for name in cells:
subcell = cells[name]
if subcell == net:
continue
elif isinstance(subcell, (nn.Dense_Thor, nn.Conv2d_Thor, nn.Embedding_Thor)):
continue
elif isinstance(subcell, (nn.Conv2dTranspose, nn.Conv1d, nn.Conv1dTranspose, nn.BatchNorm1d, nn.GroupNorm,
nn.GlobalBatchNorm, nn.LayerNorm, nn.BatchNorm2d, nn.MaxPool2d)):
continue
elif isinstance(subcell, (nn.Embedding, nn.Dense, nn.Conv2d)):
prefix = subcell.param_prefix
new_subcell = self._convert_method_map[type(subcell)](subcell)
print("subcell name: ", name, "prefix is", prefix, flush=True)
if isinstance(new_subcell, (nn.Dense_Thor, nn.Embedding_Thor, nn.Conv2d_Thor)):
print("convert to thor layer success.", flush=True)
new_subcell.update_parameters_name(prefix + '.')
net.insert_child_to_cell(name, new_subcell)
change = True
else:
self._convert_to_thor_net(subcell)

if isinstance(net, nn.SequentialCell) and change:
print("is nn.SequentialCell and change")
net.cell_list = list(net.cells())


def convert_to_thor_net(self, net):
"""
api for convert net to thor net
"""
net.update_cell_prefix()
self._convert_to_thor_net(net)
net.update_cell_type("second_order")


class ConvertModelUtils():
"""
convert model to thor model utils
"""

def convert_to_thor_model(self, model, network, loss_fn=None, optimizer=None, metrics=None, amp_level="O0",
loss_scale_manager=None, keep_batchnorm_fp32=False, frequency=834):

"""
api for convert model to thor model
"""
optim_name = type(optimizer).__name__
if optim_name in ("THOR_Ascend", "THOR_GPU"):
from .model_thor import Model_Thor
if isinstance(network, nn.TrainOneStepCell):
model = Model_Thor(network=network, frequency=frequency)
else:
model = Model_Thor(network=network, loss_fn=loss_fn, optimizer=optimizer, amp_level=amp_level,
loss_scale_manager=loss_scale_manager,
keep_batchnorm_fp32=keep_batchnorm_fp32, metrics=metrics, frequency=frequency)

return model

+ 188
- 0
mindspore/train/train_thor/dataset_helper.py View File

@@ -0,0 +1,188 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset help for minddata dataset"""
import math
import os
from mindspore._checkparam import Validator
from mindspore import context
from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes
from mindspore.nn.wrap import GetNextSingleOp
from mindspore.parallel._utils import _get_device_num, _need_to_full, _to_full_shapes


def _send_data(dataset, epoch_num):
"""Engine dataset to write data to tdt queue."""
if not hasattr(dataset, '__has_sent__'):
exec_dataset = dataset.__transfer_dataset__
exec_dataset.send(epoch_num)
dataset.__has_sent__ = True


def _send_data_no_flag(dataset, epoch_num):
"""Engine dataset to write data to tdt queue directly."""
exec_dataset = dataset.__transfer_dataset__
exec_dataset.send(epoch_num)


class DatasetHelper:
"""
Help function to use the MindData dataset.

According to different contexts, change the iterations of dataset and use the same iteration for loop in different
contexts.

Note:
The iteration of DatasetHelper will provide one epoch data.

Args:
dataset (DataSet): The training dataset iterator.
dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True.
sink_size (int): Control the amount of data in each sink.
If sink_size=-1, sink the complete dataset for each epoch.
If sink_size>0, sink sink_size data for each epoch. Default: -1.
epoch_num (int): Control the number of epoch data to send. Default: 1.

Examples:
>>> dataset_helper = DatasetHelper(dataset)
>>> for inputs in dataset_helper:
>>> outputs = network(*inputs)
"""

def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1, iter_first_order=1):
dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
Validator.check_is_int(sink_size)
if sink_size < -1 or sink_size == 0:
raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))

if dataset_sink_mode:
if context.get_context("device_target") == "Ascend":
iterclass = _DatasetIterMSLoopSink
self.iter = iterclass(dataset, sink_size, epoch_num, iter_first_order)
elif context.get_context("device_target") == "GPU":
iterclass = _DatasetIterMS
self.iter = iterclass(dataset, sink_size, epoch_num)
elif context.get_context("device_target") == "CPU":
raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")

def __iter__(self):
return self.iter.__iter__()

# A temp solution for loop sink. Delete later
def types_shapes(self):
"""Get the types and shapes from dataset on the current configuration."""
return self.iter.types_shapes()

def sink_size(self):
"""Get sink_size for each iteration."""
return self.iter.get_sink_size()

def stop_send(self):
"""Free up resources about data sink."""
self.iter.stop_send()


class _DatasetIter:
"""Base iter for dataset helper"""
def __init__(self, dataset, sink_size, epoch_num):
self.dataset = dataset
self.sink_size = sink_size
self.sink_count = 1

if not hasattr(dataset, '__transfer_dataset__'):
if hasattr(dataset, '__loop_size__'):
self.sink_size = dataset.__loop_size__
dataset.__transfer_dataset__ = _exec_datagraph(dataset, self.sink_size)

if not hasattr(dataset, '__no_send__'):
_send_data(dataset, epoch_num)
else:
_send_data_no_flag(dataset, epoch_num)

self.stop_send = dataset.__transfer_dataset__.stop_send
self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)

def __iter__(self):
self.index = 0
return self

def __next__(self):
if self.index >= self.sink_count:
raise StopIteration()
self.index += 1
return self.op()

def types_shapes(self):
return self.dataset_types, self.dataset_shapes

def get_sink_count(self, dataset):
sink_count = 1
if hasattr(dataset, '__loop_size__'):
loop_size = dataset.__loop_size__
if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
f'sink_size {loop_size} are not matched.')
sink_count = math.ceil(dataset.get_dataset_size() / loop_size)
return sink_count

def get_sink_size(self):
"""get sink_size to device"""
sink_size = 1
if hasattr(self.dataset, '__loop_size__'):
sink_size = self.dataset.__loop_size__
else:
if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend":
if self.sink_size > 0:
sink_size = self.sink_size
else:
sink_size = self.dataset.get_dataset_size()
return sink_size


class _DatasetIterMSLoopSink(_DatasetIter):
"""Iter for context when device_target is Ascend"""
def __init__(self, dataset, sink_size, epoch_num, iter_first_order):
super().__init__(dataset, sink_size, epoch_num)
sink_count = 1
if hasattr(dataset, '__loop_size__'):
loop_size = dataset.__loop_size__ + iter_first_order
sink_count = int(sink_size / loop_size) * 2
self.sink_count = sink_count
ms_role = os.getenv("MS_ROLE")
if ms_role in ("MS_PSERVER", "MS_SCHED"):
self.sink_count = 1
# for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
# use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
# compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
if _need_to_full():
device_num = _get_device_num()
self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)

def op():
return tuple()

self.op = op


class _DatasetIterMS(_DatasetIter):
"""Iter for MS when enable_loop_sink is False."""
def __init__(self, dataset, sink_size, epoch_num):
super().__init__(dataset, sink_size, epoch_num)
if sink_size > 0:
self.sink_count = sink_size
else:
self.sink_count = dataset.get_dataset_size()

queue_name = dataset.__transfer_dataset__.queue_name
self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)

+ 236
- 0
mindspore/train/train_thor/model_thor.py View File

@@ -0,0 +1,236 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Model."""

import math
from mindspore.train.callback import RunContext
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.train.model import Model
from mindspore.train.dataset_helper import connect_network_with_dataset
from mindspore.parallel._utils import _need_to_full, _to_full_tensor
from mindspore.common.dtype import pytype_to_dtype
from mindspore._c_expression import init_exec_dataset
from .dataset_helper import DatasetHelper

def _convert_type(types):
"""
Convert from numpy type to tensor type.

Args:
types (list): Numpy type list of element in dataset.

Returns:
list, list of element in dataset.
"""
ms_types = []
for np_type in types:
ms_type = pytype_to_dtype(np_type)
ms_types.append(ms_type)
return ms_types


def _get_types_and_shapes(dataset):
"""Get dataset types and shapes."""
dataset_types = _convert_type(dataset.output_types())
dataset_shapes = dataset.output_shapes()
return dataset_types, dataset_shapes


def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
"""Initialize and execute the dataset graph."""
batch_size = exec_dataset.get_batch_size()
input_indexs = exec_dataset.input_indexs

# transform data format
dataset_types, dataset_shapes = _get_types_and_shapes(exec_dataset)
init_exec_dataset(exec_dataset.__transfer_dataset__.queue_name,
dataset_size,
batch_size,
dataset_types,
dataset_shapes,
input_indexs,
phase=phase,
need_run=False)


class Model_Thor(Model):
"""
High-Level API for Training or Testing.

`Model` groups layers into an object with training and inference features.

Args:
network (Cell): A training or testing network.
loss_fn (Cell): Objective function, if loss_fn is None, the
network should contain the logic of loss and grads calculation, and the logic
of parallel if needed. Default: None.
optimizer (Cell): Optimizer for updating the weights. Default: None.
metrics (Union[dict, set]): A Dictionary or a set of metrics to be evaluated by the model during
training and testing. eg: {'accuracy', 'recall'}. Default: None.
eval_network (Cell): Network for evaluation. If not defined, `network` and `loss_fn` would be wrapped as
`eval_network`. Default: None.
eval_indexes (list): When defining the `eval_network`, if `eval_indexes` is None, all outputs of the
`eval_network` would be passed to metrics, otherwise `eval_indexes` must contain three
elements, including the positions of loss value, predicted value and label. The loss
value would be passed to the `Loss` metric, the predicted value and label would be passed
to other metric. Default: None.
amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed
precision training. Supports [O0, O2, O3]. Default: "O0".

- O0: Do not change.
- O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.

O2 is recommended on GPU, O3 is recommended on Ascend.

loss_scale_manager (Union[None, LossScaleManager]): If it is None, the loss would not be scaled. Otherwise,
scale the loss by LossScaleManager. It is a key argument.
e.g. Use `loss_scale_manager=None` to set the value.
keep_batchnorm_fp32 (bool): Keep Batchnorm running in `float32`. If it is set to true, the level setting before
will be overwritten. Default: True.
"""

def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
eval_indexes=None, amp_level="O0", frequency=834, **kwargs):
super(Model_Thor, self).__init__(network, loss_fn, optimizer, metrics, eval_network,
eval_indexes, amp_level, **kwargs)
self._frequency = frequency
self._train_network = self._build_train_network()

def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, sink_size=-1,
epoch_num=1, iter_first_order=1):
"""Initializes dataset."""
if dataset_sink_mode and not is_train:
dataset.__loop_size__ = 1
dataset_helper = DatasetHelper(dataset, dataset_sink_mode, sink_size, epoch_num, iter_first_order)

if dataset_sink_mode and context.get_context("device_target") != "GPU":
network = connect_network_with_dataset(network, dataset_helper)
network.set_train(is_train)
network.phase = phase

if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network.set_auto_parallel()

return dataset_helper, network

def _train_dataset_sink_process(self, epoch, train_dataset, list_callback=None, cb_params=None, sink_size=-1):
"""
Training process. The data would be passed to network through dataset channel.

Args:
epoch (int): Total number of iterations on the data.
train_dataset (Dataset): A training dataset iterator. If there is no
loss_fn, a tuple with multiple data (data1, data2, data3, ...) should be
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned. The data and label would be passed to the network and loss
function respectively.
list_callback (Callback): Executor of callback list. Default: None.
cb_params (_InternalCallbackParam): Callback parameters. Default: None.
sink_size (int): Control the amount of data in each sink. Default: -1.
"""
if sink_size == -1:
epoch_num = epoch
else:
epoch_num = math.ceil(epoch * sink_size / train_dataset.get_dataset_size())

iter_first_order = self._frequency - 1
iter_second_order = 1
train_dataset.__loop_size__ = iter_second_order
dataset_helper, train_network = self._exec_preprocess(self._train_network,
is_train=True,
phase='train',
dataset=train_dataset,
dataset_sink_mode=True,
sink_size=sink_size,
epoch_num=epoch_num,
iter_first_order=iter_first_order)

self._train_network = train_network
cb_params.train_network = self._train_network
cb_params.cur_step_num = 0

run_context = RunContext(cb_params)
list_callback.begin(run_context)

# used to stop training for early stop, such as stopAtTIme or stopATStep
should_stop = False
switch_branch_one = True
index_first_order = 0
train_network_init_flag = True
has_do_dataset_init = False

for i in range(epoch):
cb_params.cur_epoch_num = i + 1
list_callback.epoch_begin(run_context)
# for data sink dataset_helper only iter once, other wise iter epoch_size times.
for inputs in dataset_helper:
if _need_to_full() and context.get_context("device_target") == "GPU":
inputs = _to_full_tensor(inputs, self._device_number, self._global_rank)
list_callback.step_begin(run_context)
if context.get_context("device_target") == "GPU":
if switch_branch_one:
cb_params.cur_step_num += 1
if train_network_init_flag:
self._train_network.add_flags_recursive(thor=True)
self._train_network.phase = 'train0'
switch_branch_one = not switch_branch_one
outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)
else:
cb_params.cur_step_num += 1
if train_network_init_flag:
self._train_network.add_flags_recursive(thor=False)
train_network_init_flag = False
self._train_network.phase = 'train1'
outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs
index_first_order += 1
if index_first_order == iter_first_order:
index_first_order = 0
switch_branch_one = not switch_branch_one
list_callback.step_end(run_context)
else:
if switch_branch_one:
cb_params.cur_step_num += 1
if train_network_init_flag:
self._train_network.add_flags_recursive(thor=True)
self._train_network.phase = 'train0'
else:
cb_params.cur_step_num += iter_first_order
if train_network_init_flag:
self._train_network.add_flags_recursive(thor=False)
train_network_init_flag = False
self._train_network.phase = 'train1'
if not has_do_dataset_init:
_exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
has_do_dataset_init = True
switch_branch_one = not switch_branch_one
outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs
list_callback.step_end(run_context)

list_callback.epoch_end(run_context)
should_stop = should_stop or run_context.get_stop_requested()
if should_stop:
break
dataset_helper.stop_send()

list_callback.end(run_context)


__all__ = ["Model_Thor"]

+ 50
- 0
model_zoo/official/cv/resnet/src/config.py View File

@@ -16,6 +16,10 @@
network config setting, will be used in train.py and eval.py
"""
from easydict import EasyDict as ed
# config optimizer for resnet50, imagenet2012. Momentum is default, Thor is optional.
cfg = ed({
'optimizer': 'Thor',
})

# config for resent50, cifar10
config1 = ed({
@@ -101,3 +105,49 @@ config4 = ed({
"lr_max": 0.3,
"lr_end": 0.0001
})

# config for resnet50, imagenet2012, Ascend 910
config_thor_Ascend = ed({
"class_num": 1001,
"batch_size": 32,
"loss_scale": 128,
"momentum": 0.9,
"weight_decay": 5e-4,
"epoch_size": 45,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 2,
"keep_checkpoint_max": 15,
"save_checkpoint_path": "./",
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0.05803,
"lr_decay": 4.04839,
"lr_end_epoch": 53,
"damping_init": 0.02714,
"damping_decay": 0.50036,
"frequency": 834,
})

# config for resnet50, imagenet2012, GPU
config_thor_gpu = ed({
"class_num": 1001,
"batch_size": 32,
"loss_scale": 128,
"momentum": 0.9,
"weight_decay": 5e-4,
"epoch_size": 40,
"pretrain_epoch_size": 0,
"save_checkpoint": True,
"save_checkpoint_epochs": 1,
"keep_checkpoint_max": 15,
"save_checkpoint_path": "./",
"use_label_smooth": True,
"label_smooth_factor": 0.1,
"lr_init": 0.05672,
"lr_decay": 4.9687,
"lr_end_epoch": 50,
"damping_init": 0.02345,
"damping_decay": 0.5467,
"frequency": 834,
})

+ 33
- 0
model_zoo/official/cv/resnet/src/lr_generator.py View File

@@ -205,3 +205,36 @@ def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch=120
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[global_step:]
return learning_rate


def get_thor_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch, decay_epochs=100):
"""get_model_lr"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
for i in range(total_steps):
epoch = (i + 1) / steps_per_epoch
base = (1.0 - float(epoch) / total_epochs) ** decay
lr_local = lr_init * base
if epoch >= decay_epochs:
lr_local = lr_local * 0.5
if epoch >= decay_epochs + 1:
lr_local = lr_local * 0.5
lr_each_step.append(lr_local)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
return learning_rate


def get_thor_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch):
"""get_model_damping"""
damping_each_step = []
total_steps = steps_per_epoch * total_epochs
for step in range(total_steps):
epoch = (step + 1) / steps_per_epoch
damping_here = damping_init * (decay_rate ** (epoch / 10))
damping_each_step.append(damping_here)
current_step = global_step
damping_each_step = np.array(damping_each_step).astype(np.float32)
damping_now = damping_each_step[current_step:]
return damping_now

+ 74
- 4
model_zoo/official/cv/resnet/src/resnet.py View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""ResNet."""
import math
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
@@ -36,12 +37,81 @@ def _weight_variable(shape, factor=0.01):
return Tensor(init_value)


def calculate_gain(nonlinearity, param=None):
"""calculate_gain"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
res = 0
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
res = 1
elif nonlinearity == 'tanh':
res = 5.0 / 3
elif nonlinearity == 'relu':
res = math.sqrt(2.0)
elif nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
# True/False are instances of int, hence check above
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
res = math.sqrt(2.0 / (1 + negative_slope ** 2))
else:
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
return res


def _calculate_fan_in_and_fan_out(tensor):
"""_calculate_fan_in_and_fan_out"""
dimensions = len(tensor)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
if dimensions == 2: # Linear
fan_in = tensor[1]
fan_out = tensor[0]
else:
num_input_fmaps = tensor[1]
num_output_fmaps = tensor[0]
receptive_field_size = 1
if dimensions > 2:
import time
time.sleep(10)
receptive_field_size = tensor[2] * tensor[3]
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out


def _calculate_correct_fan(tensor, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == 'fan_in' else fan_out


def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
return np.random.normal(0, std, size=inputs_shape).astype(np.float32)


def kaiming_uniform(inputs_shape, a=0., mode='fan_in', nonlinearity='leaky_relu'):
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)


def _conv3x3(in_channel, out_channel, stride=1, use_se=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
else:
weight_shape = (out_channel, in_channel, 3, 3)
weight = _weight_variable(weight_shape)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return nn.Conv2d(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)

@@ -51,7 +121,7 @@ def _conv1x1(in_channel, out_channel, stride=1, use_se=False):
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
else:
weight_shape = (out_channel, in_channel, 1, 1)
weight = _weight_variable(weight_shape)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return nn.Conv2d(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)

@@ -61,7 +131,7 @@ def _conv7x7(in_channel, out_channel, stride=1, use_se=False):
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
else:
weight_shape = (out_channel, in_channel, 7, 7)
weight = _weight_variable(weight_shape)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return nn.Conv2d(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)

@@ -82,7 +152,7 @@ def _fc(in_channel, out_channel, use_se=False):
weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
else:
weight_shape = (out_channel, in_channel)
weight = _weight_variable(weight_shape)
weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)




+ 29
- 7
model_zoo/official/cv/resnet/train.py View File

@@ -18,9 +18,10 @@ import argparse
import ast
from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.optim.momentum import Momentum, THOR
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.train_thor import ConvertModelUtils
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.loss_scale_manager import FixedLossScaleManager
@@ -32,6 +33,7 @@ import mindspore.nn as nn
import mindspore.common.initializer as weight_init
from src.lr_generator import get_lr, warmup_cosine_annealing_lr
from src.CrossEntropySmooth import CrossEntropySmooth
from src.config import cfg

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
@@ -65,6 +67,12 @@ else:
from src.config import config4 as config
from src.dataset import create_dataset4 as create_dataset

if cfg.optimizer == "Thor":
if args_opt.device_target == "Ascend":
from src.config import config_thor_Ascend as config
else:
from src.config import config_thor_gpu as config


if __name__ == '__main__':
target = args_opt.device_target
@@ -124,13 +132,17 @@ if __name__ == '__main__':
cell.weight.dtype))

# init lr
if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
lr_decay_mode=config.lr_decay_mode)
if cfg.optimizer == "Thor":
from src.lr_generator import get_thor_lr
lr = get_thor_lr(0, config.lr_init, config.lr_decay, config.lr_end_epoch, step_size, decay_epochs=39)
else:
lr = warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, config.epoch_size,
config.pretrain_epoch_size * step_size)
if args_opt.net == "resnet50" or args_opt.net == "se-resnet50":
lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max,
warmup_epochs=config.warmup_epochs, total_epochs=config.epoch_size, steps_per_epoch=step_size,
lr_decay_mode=config.lr_decay_mode)
else:
lr = warmup_cosine_annealing_lr(config.lr, step_size, config.warmup_epochs, config.epoch_size,
config.pretrain_epoch_size * step_size)
lr = Tensor(lr)

# define opt
@@ -180,6 +192,16 @@ if __name__ == '__main__':
## fp32 training
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
if cfg.optimizer == "Thor" and args_opt.dataset == "imagenet2012":
from src.lr_generator import get_thor_damping
damping = get_thor_damping(0, config.damping_init, config.damping_decay, 70, step_size)
split_indices = [26, 53]
opt = THOR(net, lr, Tensor(damping), config.momentum, config.weight_decay, config.loss_scale,
config.batch_size, split_indices=split_indices)
model = ConvertModelUtils().convert_to_thor_model(model=model, network=net, loss_fn=loss, optimizer=opt,
loss_scale_manager=loss_scale, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False,
frequency=config.frequency)

# define callbacks
time_cb = TimeMonitor(data_size=step_size)


+ 1
- 1
model_zoo/official/nlp/bert/scripts/run_distributed_pretrain_ascend.sh View File

@@ -23,7 +23,7 @@ echo "For hyper parameter, please note that you should customize the scripts:
'{CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini' "
echo "=============================================================================================================="
CUR_DIR=`pwd`
ulimit -s 102400
python ${CUR_DIR}/scripts/ascend_distributed_launcher/get_distribute_pretrain_cmd.py \
--run_script_dir=${CUR_DIR}/run_pretrain.py \
--hyper_parameter_config_dir=${CUR_DIR}/scripts/ascend_distributed_launcher/hyper_parameter_config.ini \


+ 1
- 0
model_zoo/official/nlp/bert/scripts/run_standalone_pretrain_ascend.sh View File

@@ -24,6 +24,7 @@ DEVICE_ID=$1
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
ulimit -s 102400

mkdir -p ms_log
PROJECT_DIR=$(cd "$(dirname "$0")" || exit; pwd)


+ 6
- 0
model_zoo/official/nlp/bert/src/config.py View File

@@ -48,6 +48,12 @@ cfg = edict({
'learning_rate': 2e-5,
'momentum': 0.9,
}),
'Thor': edict({
'momentum': 0.9,
'weight_decay': 5e-4,
'loss_scale': 1.0,
'frequency': 100,
}),
})

'''


+ 61
- 3
model_zoo/official/nlp/bert/src/utils.py View File

@@ -22,6 +22,7 @@ import math
import collections
import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore import log as logger
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
@@ -106,11 +107,10 @@ class LossCallBack(Callback):
percent = 1
epoch_num -= 1
print("epoch: {}, current epoch percent: {}, step: {}, outputs are {}"
.format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs),
flush=True))
.format(int(epoch_num), "%.3f" % percent, cb_params.cur_step_num, str(cb_params.net_outputs)))
else:
print("epoch: {}, step: {}, outputs are {}".format(cb_params.cur_epoch_num, cb_params.cur_step_num,
str(cb_params.net_outputs), flush=True))
str(cb_params.net_outputs)))

def LoadNewestCkpt(load_finetune_checkpoint_dir, steps_per_epoch, epoch_num, prefix):
"""
@@ -181,3 +181,61 @@ def convert_labels_to_index(label_list):
sub_label = pre + label
label2id[sub_label] = index
return label2id

def _get_poly_lr(global_step, lr_init, lr_end, lr_max, warmup_steps, total_steps, poly_power):
"""
generate learning rate array

Args:
global_step(int): current step
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_steps(int): number of warmup epochs
total_steps(int): total epoch of training
poly_power(int): poly learning rate power

Returns:
np.array, learning rate array
"""
lr_each_step = []
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < warmup_steps:
lr = float(lr_init) + inc_each_step * float(i)
else:
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr = float(lr_max - lr_end) * (base ** poly_power)
lr = lr + lr_end
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)

learning_rate = np.array(lr_each_step).astype(np.float32)
current_step = global_step
learning_rate = learning_rate[current_step:]
return learning_rate


def get_bert_thor_lr():
if context.get_context("device_target") == "Ascend":
learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=3.244018779068399e-05,
lr_max=0.0034022148941459055, warmup_steps=0, total_steps=30000, poly_power=1)
else:
learning_rate = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=1.7e-3, warmup_steps=0,
total_steps=30000, poly_power=1)

return Tensor(learning_rate)


def get_bert_thor_damping():
if context.get_context("device_target") == "Ascend":
damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=5e-2, warmup_steps=0, total_steps=30000,
poly_power=1)
else:
damping = _get_poly_lr(global_step=0, lr_init=0.0, lr_end=1e-6, lr_max=3.5e-2, warmup_steps=0,
total_steps=30000, poly_power=1)
return Tensor(damping)

tests/st/networks/models/bert/bert_performance/test_bert_thor_mlperf.py → tests/st/networks/models/bert/bert_performance/test_bert_thor.py View File

@@ -28,20 +28,44 @@ from mindspore import log as logger
from mindspore.train.callback import Callback
from mindspore.context import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import THOR
from mindspore.train.model import Model
from mindspore.train.train_thor import ConvertModelUtils
import mindspore.dataset.transforms.c_transforms as C
from model_zoo.official.nlp.bert_thor.src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepCell
from model_zoo.official.nlp.bert_thor.src.bert_net_config import bert_net_cfg
from model_zoo.official.nlp.bert_thor.src.config import cfg
from model_zoo.official.nlp.bert_thor.src.lr_generator import get_bert_lr, get_bert_damping
from model_zoo.official.nlp.bert_thor.src.model_thor import Model
from model_zoo.official.nlp.bert_thor.src.thor_for_bert_arg import THOR

from model_zoo.official.nlp.bert.src.bert_for_pre_training import BertNetworkWithLoss, BertTrainOneStepCell
from model_zoo.official.nlp.bert.src.utils import get_bert_thor_lr, get_bert_thor_damping
from model_zoo.official.nlp.bert.src.bert_model import BertConfig

MINDSPORE_HCCL_CONFIG_PATH = "/home/workspace/mindspore_config/hccl/rank_table_8p.json"
DATASET_PATH = "/home/workspace/mindspore_dataset/bert/thor/en-wiki-512_test_first1wan"

load_checkpoint_path = ""
data_sink_steps = 100
train_steps = 200
batch_size = 12
frequency = 100
momentum = 0.9
weight_decay = 5e-4
loss_scale = 1.0

bert_net_cfg = BertConfig(
seq_length=512,
vocab_size=30522,
hidden_size=1024,
num_hidden_layers=4,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=False,
dtype=mstype.float32,
compute_type=mstype.float16
)

np.random.seed(1)
ds.config.set_seed(1)
@@ -113,27 +137,7 @@ def create_bert_dataset(device_num=1, rank=0, do_shuffle="true", data_dir=None,

def _set_bert_all_reduce_split():
"""set bert all_reduce fusion split, support num_hidden_layers is 12 and 24."""
from mindspore.parallel._auto_parallel_context import auto_parallel_context
if bert_net_cfg.num_hidden_layers == 12:
if bert_net_cfg.use_relative_positions:
auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217],
"hccl_world_groupsum1")
auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217],
"hccl_world_groupsum3")
else:
auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205],
"hccl_world_groupsum1")
auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205],
"hccl_world_groupsum3")
elif bert_net_cfg.num_hidden_layers == 24:
if bert_net_cfg.use_relative_positions:
auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421],
"hccl_world_groupsum1")
auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421],
"hccl_world_groupsum3")
else:
auto_parallel_context().set_all_reduce_fusion_split_indices([38, 77], "hccl_world_groupsum1")
auto_parallel_context().set_all_reduce_fusion_split_indices([38, 77], "hccl_world_groupsum3")
context.set_auto_parallel_context(all_reduce_fusion_config=[38, 77])


def train_process_bert_thor(q, device_id, epoch_size, device_num):
@@ -153,7 +157,6 @@ def train_process_bert_thor(q, device_id, epoch_size, device_num):
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
device_num=device_num)

bert_net_cfg.num_hidden_layers = 4
data_set = create_bert_dataset(device_num=device_num, rank=rank, do_shuffle=False, data_dir=DATASET_PATH,
schema_dir=None)
net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)
@@ -161,13 +164,12 @@ def train_process_bert_thor(q, device_id, epoch_size, device_num):
new_repeat_count = epoch_size * data_set.get_dataset_size() // data_sink_steps
new_repeat_count = min(new_repeat_count, train_steps // data_sink_steps)

lr = get_bert_lr()
damping = get_bert_damping()
optimizer = THOR(filter(lambda x: x.requires_grad, net_with_loss.get_parameters()), lr, cfg.Thor.momentum,
filter(lambda x: 'matrix_A' in x.name, net_with_loss.get_parameters()),
filter(lambda x: 'matrix_G' in x.name, net_with_loss.get_parameters()),
cfg.Thor.weight_decay, cfg.Thor.loss_scale, bert_net_cfg.num_hidden_layers,
bert_net_cfg.batch_size, damping)
lr = get_bert_thor_lr()
damping = get_bert_thor_damping()
split_indices = [38, 77]
optimizer = THOR(net_with_loss, lr, damping, momentum, weight_decay, loss_scale, batch_size,
decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
split_indices=split_indices)
time_monitor_callback = TimeMonitor(data_sink_steps)
loss_callback = LossCallback()
callback = [time_monitor_callback, loss_callback]
@@ -177,7 +179,9 @@ def train_process_bert_thor(q, device_id, epoch_size, device_num):
load_param_into_net(net_with_loss, param_dict)

net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
model = Model(net_with_grads, frequency=cfg.Thor.frequency)
model = Model(net_with_grads)
model = ConvertModelUtils().convert_to_thor_model(model, network=net_with_grads, optimizer=optimizer,
frequency=frequency)
model.train(new_repeat_count, data_set, callbacks=callback, dataset_sink_mode=True, sink_size=data_sink_steps)

loss_list = loss_callback.loss_list
@@ -230,9 +234,12 @@ def test_bert_thor_mlperf_8p():
os.system("rm -rf " + str(i))

print("End training...")
assert mean_cost < 64.4
assert mean_loss < 7.9
assert mean_cost < 71.5
assert mean_loss < 8.125


if __name__ == '__main__':
begin = time.time()
test_bert_thor_mlperf_8p()
end = time.time()
print("time span is", end - begin, flush=True)

+ 1
- 1
tests/st/networks/models/resnet50/src_thor/config.py View File

@@ -18,7 +18,7 @@ network config setting, will be used in train.py and eval.py
from easydict import EasyDict as ed

config = ed({
"class_num": 1000,
"class_num": 1001,
"batch_size": 32,
"loss_scale": 128,
"momentum": 0.9,


+ 0
- 135
tests/st/networks/models/resnet50/src_thor/grad_reducer_thor.py View File

@@ -1,135 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""grad reducer cell for distributed training"""
from mindspore.nn.cell import Cell
from mindspore.communication.management import GlobalComm, get_group_size
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.ops.operations.comm_ops import AllReduce
import mindspore.common.dtype as mstype

reduce_opt = C.MultitypeFuncGraph("reduce_opt")


def _init_allreduce_operators(length, split_indices):
""" initialize allreduce communication operators"""
indices = split_indices[0]
fusion = split_indices[1]
op_list = ()
j = 0
for i in range(length):
if j <= len(indices)-1:
temp = indices[j]
else:
temp = length
if i >= temp:
j = j + 1
fusion = fusion + 1
op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
op.add_prim_attr('fusion', fusion)
op_list = op_list + (op,)
return op_list


@reduce_opt.register("Function", "Number", "Function", "Tensor")
def _tensors_allreduce_mean(mul, degree, allreduce, parameters):
"""
Apply allreduce on parameters.

Args:
mul(Primitive): The mul operator for parameters.
degree (int): The mean coefficient.
allreduce (Primitive): The communication operator for parameters.
parameters (Tensor): The parameters before operation.

Returns:
Tensor, the parameters after operation.
"""
degree = F.scalar_cast(degree, F.dtype(parameters))
parameters = allreduce(parameters)
cast_op = P.Cast()
return mul(parameters, cast_op(F.scalar_to_array(1.0 / degree), F.dtype(parameters)))


_get_datatype = C.MultitypeFuncGraph("_get_datatype")


@_get_datatype.register("Tensor")
def _tensors_get_datatype(parameters):
"""
Acquire parameters datatype.

Args:
parameters (Tensor): The parameters before operation.

Returns:
mstype, the datatype of parameters.
"""
return F.dtype(parameters)


_cast_datatype = C.MultitypeFuncGraph("_cast_datatype")


@_cast_datatype.register("TypeType", "Tensor")
def _tensors_cast_datatype(datatype, parameters):
"""
Cast parameters to datatype.

Args:
datatype (mstype): the destination datatype of parameters.
parameters (Tensor): The parameters before operation.

Returns:
Tensor, the parameters after operation.
"""
return F.cast(parameters, datatype)


class DistributedGradReducerThor(Cell):
"""
A distributed optimizer.

Constructs a parameters reducer Cell, which applies communication and average operations on
single-process parameters values.

Args:
parameter_length (int): length of the parameters to be updated.
split_indices(tuple): parameter split indices.
mean (bool): When mean is true, the mean coefficient (degree) would apply on parameters. Default: False.
degree (int): The mean coefficient. Usually it equals to device number. Default: None.

Raises:
ValueError: If degree is not a int or less than 0.
"""

def __init__(self, parameter_length, split_indices, mean=True, degree=None):
super(DistributedGradReducerThor, self).__init__(auto_prefix=False)
self.hyper_map = C.HyperMap()
self.mul = P.Mul()
if degree is None:
self.degree = get_group_size()
else:
if not isinstance(degree, int) or degree <= 0:
raise ValueError("Parameter 'degree' in DistributedGradReducer should large than 0 and be int")
self.degree = degree
self.mean = mean
self.op_list = _init_allreduce_operators(parameter_length, split_indices)

def construct(self, parameters):
datatypes = self.hyper_map(F.partial(_get_datatype), parameters)
parameters = self.hyper_map(F.partial(_cast_datatype, mstype.float32), parameters)
new_parameters = self.hyper_map(F.partial(reduce_opt, self.mul, self.degree), self.op_list, parameters)
new_parameters = self.hyper_map(F.partial(_cast_datatype), datatypes, new_parameters)
return new_parameters

+ 0
- 88
tests/st/networks/models/resnet50/src_thor/lr_generator.py View File

@@ -1,88 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate generator"""
import math

import numpy as np


def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
"""
generate learning rate array

Args:
lr_init(float): init learning rate
lr_end(float): end learning rate
lr_max(float): max learning rate
warmup_epochs(int): number of warmup epochs
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or default

Returns:
np.array, learning rate array
"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
warmup_steps = steps_per_epoch * warmup_epochs
if lr_decay_mode == 'steps':
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
for i in range(total_steps):
if i < decay_epoch_index[0]:
lr = lr_max
elif i < decay_epoch_index[1]:
lr = lr_max * 0.1
elif i < decay_epoch_index[2]:
lr = lr_max * 0.01
else:
lr = lr_max * 0.001
lr_each_step.append(lr)
elif lr_decay_mode == 'poly':
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < warmup_steps:
lr = float(lr_init) + inc_each_step * float(i)
else:
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr = float(lr_max) * base * base
if lr < 0.0:
lr = 0.0
lr_each_step.append(lr)
elif lr_decay_mode == 'cosine':
decay_steps = total_steps - warmup_steps
for i in range(total_steps):
if i < warmup_steps:
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
lr = float(lr_init) + lr_inc * (i + 1)
else:
linear_decay = (total_steps - i) / decay_steps
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
decayed = linear_decay * cosine_decay + 0.00001
lr = lr_max * decayed
lr_each_step.append(lr)
else:
for i in range(total_steps):
if i < warmup_steps:
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
else:
lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
lr_each_step.append(lr)

learning_rate = np.array(lr_each_step).astype(np.float32)

return learning_rate

+ 188
- 154
tests/st/networks/models/resnet50/src_thor/resnet.py View File

@@ -13,104 +13,57 @@
# limitations under the License.
# ============================================================================
"""ResNet."""
import math
import numpy as np
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P

from .thor_layer import Conv2d_Thor, Dense_Thor


def calculate_gain(nonlinearity, param=None):
"""calculate_gain"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
res = 0
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
res = 1
elif nonlinearity == 'tanh':
res = 5.0 / 3
elif nonlinearity == 'relu':
res = math.sqrt(2.0)
elif nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
# True/False are instances of int, hence check above
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
res = math.sqrt(2.0 / (1 + negative_slope ** 2))
else:
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
return res


def _calculate_fan_in_and_fan_out(tensor):
"""_calculate_fan_in_and_fan_out"""
dimensions = len(tensor)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
if dimensions == 2: # Linear
fan_in = tensor[1]
fan_out = tensor[0]
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
from scipy.stats import truncnorm

def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
fan_in = in_channel * kernel_size * kernel_size
scale = 1.0
scale /= max(1., fan_in)
stddev = (scale ** 0.5) / .87962566103423978
mu, sigma = 0, stddev
weight = truncnorm(-2, 2, loc=mu, scale=sigma).rvs(out_channel * in_channel * kernel_size * kernel_size)
weight = np.reshape(weight, (out_channel, in_channel, kernel_size, kernel_size))
return Tensor(weight, dtype=mstype.float32)

def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)


def _conv3x3(in_channel, out_channel, stride=1, use_se=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=3)
else:
num_input_fmaps = tensor[1]
num_output_fmaps = tensor[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = tensor[2] * tensor[3]
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out


def _calculate_correct_fan(tensor, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == 'fan_in' else fan_out

weight_shape = (out_channel, in_channel, 3, 3)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)

def kaiming_normal(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
return np.random.normal(0, std, size=inputs_shape).astype(np.float32)


def kaiming_uniform(inputs_shape, a=0, mode='fan_in', nonlinearity='leaky_relu'):
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return np.random.uniform(-bound, bound, size=inputs_shape).astype(np.float32)


def _conv3x3(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
weight_shape = (out_channel, in_channel, 3, 3)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return Conv2d_Thor(in_channel, out_channel,
kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency)


def _conv1x1(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
weight_shape = (out_channel, in_channel, 1, 1)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return Conv2d_Thor(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency)
def _conv1x1(in_channel, out_channel, stride=1, use_se=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=1)
else:
weight_shape = (out_channel, in_channel, 1, 1)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=1, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _conv7x7(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
weight_shape = (out_channel, in_channel, 7, 7)
weight = Tensor(kaiming_normal(weight_shape, mode="fan_out", nonlinearity='relu'))
return Conv2d_Thor(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight,
damping=damping, loss_scale=loss_scale, frequency=frequency)
def _conv7x7(in_channel, out_channel, stride=1, use_se=False):
if use_se:
weight = _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size=7)
else:
weight_shape = (out_channel, in_channel, 7, 7)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channel, out_channel,
kernel_size=7, stride=stride, padding=0, pad_mode='same', weight_init=weight)


def _bn(channel):
@@ -120,14 +73,17 @@ def _bn(channel):

def _bn_last(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9,
gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)
gamma_init=0, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _fc(in_channel, out_channel, damping, loss_scale, frequency):
weight_shape = (out_channel, in_channel)
weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
return Dense_Thor(in_channel, out_channel, has_bias=False, weight_init=weight,
bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency)
def _fc(in_channel, out_channel, use_se=False):
if use_se:
weight = np.random.normal(loc=0, scale=0.01, size=out_channel*in_channel)
weight = Tensor(np.reshape(weight, (out_channel, in_channel)), dtype=mstype.float32)
else:
weight_shape = (out_channel, in_channel)
weight = _weight_variable(weight_shape)
return nn.Dense(in_channel, out_channel, has_bias=True, weight_init=weight, bias_init=0)


class ResidualBlock(nn.Cell):
@@ -138,6 +94,8 @@ class ResidualBlock(nn.Cell):
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
use_se (bool): enable SE-ResNet50 net. Default: False.
se_block(bool): use se block in SE-ResNet50 net. Default: False.

Returns:
Tensor, output tensor.
@@ -151,24 +109,29 @@ class ResidualBlock(nn.Cell):
in_channel,
out_channel,
stride=1,
damping=0.03,
loss_scale=1,
frequency=278):
use_se=False, se_block=False):
super(ResidualBlock, self).__init__()

self.stride = stride
self.use_se = use_se
self.se_block = se_block
channel = out_channel // self.expansion
self.conv1 = _conv1x1(in_channel, channel, stride=1, damping=damping, loss_scale=loss_scale,
frequency=frequency)
self.conv1 = _conv1x1(in_channel, channel, stride=1, use_se=self.use_se)
self.bn1 = _bn(channel)
if self.use_se and self.stride != 1:
self.e2 = nn.SequentialCell([_conv3x3(channel, channel, stride=1, use_se=True), _bn(channel),
nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')])
else:
self.conv2 = _conv3x3(channel, channel, stride=stride, use_se=self.use_se)
self.bn2 = _bn(channel)

self.conv2 = _conv3x3(channel, channel, stride=stride, damping=damping, loss_scale=loss_scale,
frequency=frequency)
self.bn2 = _bn(channel)

self.conv3 = _conv1x1(channel, out_channel, stride=1, damping=damping, loss_scale=loss_scale,
frequency=frequency)
self.conv3 = _conv1x1(channel, out_channel, stride=1, use_se=self.use_se)
self.bn3 = _bn_last(out_channel)

if self.se_block:
self.se_global_pool = P.ReduceMean(keep_dims=False)
self.se_dense_0 = _fc(out_channel, int(out_channel/4), use_se=self.use_se)
self.se_dense_1 = _fc(int(out_channel/4), out_channel, use_se=self.use_se)
self.se_sigmoid = nn.Sigmoid()
self.se_mul = P.Mul()
self.relu = nn.ReLU()

self.down_sample = False
@@ -178,10 +141,17 @@ class ResidualBlock(nn.Cell):
self.down_sample_layer = None

if self.down_sample:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
damping=damping, loss_scale=loss_scale,
frequency=frequency),
_bn(out_channel)])
if self.use_se:
if stride == 1:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel,
stride, use_se=self.use_se), _bn(out_channel)])
else:
self.down_sample_layer = nn.SequentialCell([nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same'),
_conv1x1(in_channel, out_channel, 1,
use_se=self.use_se), _bn(out_channel)])
else:
self.down_sample_layer = nn.SequentialCell([_conv1x1(in_channel, out_channel, stride,
use_se=self.use_se), _bn(out_channel)])
self.add = P.Add()

def construct(self, x):
@@ -190,13 +160,23 @@ class ResidualBlock(nn.Cell):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)

if self.use_se and self.stride != 1:
out = self.e2(out)
else:
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.se_block:
out_se = out
out = self.se_global_pool(out, (2, 3))
out = self.se_dense_0(out)
out = self.relu(out)
out = self.se_dense_1(out)
out = self.se_sigmoid(out)
out = F.reshape(out, F.shape(out) + (1, 1))
out = self.se_mul(out, out_se)

if self.down_sample:
identity = self.down_sample_layer(identity)
@@ -218,6 +198,8 @@ class ResNet(nn.Cell):
out_channels (list): Output channel in each layer.
strides (list): Stride size in each layer.
num_classes (int): The number of classes that the training images are belonging to.
use_se (bool): enable SE-ResNet50 net. Default: False.
se_block(bool): use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
Returns:
Tensor, output tensor.

@@ -237,57 +219,59 @@ class ResNet(nn.Cell):
out_channels,
strides,
num_classes,
damping,
loss_scale,
frequency):
use_se=False):
super(ResNet, self).__init__()

if not len(layer_nums) == len(in_channels) == len(out_channels) == 4:
raise ValueError("the length of layer_num, in_channels, out_channels list must be 4!")

self.conv1 = _conv7x7(3, 64, stride=2, damping=damping, loss_scale=loss_scale, frequency=frequency)
self.use_se = use_se
self.se_block = False
if self.use_se:
self.se_block = True

if self.use_se:
self.conv1_0 = _conv3x3(3, 32, stride=2, use_se=self.use_se)
self.bn1_0 = _bn(32)
self.conv1_1 = _conv3x3(32, 32, stride=1, use_se=self.use_se)
self.bn1_1 = _bn(32)
self.conv1_2 = _conv3x3(32, 64, stride=1, use_se=self.use_se)
else:
self.conv1 = _conv7x7(3, 64, stride=2)
self.bn1 = _bn(64)
self.relu = P.ReLU()
self.maxpool = P.MaxPoolWithArgmax(pad_mode="same", kernel_size=3, strides=2)

self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.layer1 = self._make_layer(block,
layer_nums[0],
in_channel=in_channels[0],
out_channel=out_channels[0],
stride=strides[0],
damping=damping,
loss_scale=loss_scale,
frequency=frequency)
use_se=self.use_se)
self.layer2 = self._make_layer(block,
layer_nums[1],
in_channel=in_channels[1],
out_channel=out_channels[1],
stride=strides[1],
damping=damping,
loss_scale=loss_scale,
frequency=frequency)
use_se=self.use_se)
self.layer3 = self._make_layer(block,
layer_nums[2],
in_channel=in_channels[2],
out_channel=out_channels[2],
stride=strides[2], damping=damping,
loss_scale=loss_scale,
frequency=frequency)
stride=strides[2],
use_se=self.use_se,
se_block=self.se_block)
self.layer4 = self._make_layer(block,
layer_nums[3],
in_channel=in_channels[3],
out_channel=out_channels[3],
stride=strides[3],
damping=damping,
loss_scale=loss_scale,
frequency=frequency)
use_se=self.use_se,
se_block=self.se_block)

self.mean = P.ReduceMean(keep_dims=True)
self.flatten = nn.Flatten()
self.end_point = _fc(out_channels[3], num_classes, damping=damping, loss_scale=loss_scale, frequency=frequency)
self.end_point = _fc(out_channels[3], num_classes, use_se=self.use_se)

def _make_layer(self, block, layer_num, in_channel, out_channel, stride,
damping, loss_scale, frequency):
def _make_layer(self, block, layer_num, in_channel, out_channel, stride, use_se=False, se_block=False):
"""
Make stage network of ResNet.

@@ -297,7 +281,7 @@ class ResNet(nn.Cell):
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
se_block(bool): use se block in SE-ResNet50 net. Default: False.
Returns:
SequentialCell, the output layer.

@@ -306,22 +290,34 @@ class ResNet(nn.Cell):
"""
layers = []

resnet_block = block(in_channel, out_channel, stride=stride,
damping=damping, loss_scale=loss_scale, frequency=frequency)
resnet_block = block(in_channel, out_channel, stride=stride, use_se=use_se)
layers.append(resnet_block)

for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1,
damping=damping, loss_scale=loss_scale, frequency=frequency)
if se_block:
for _ in range(1, layer_num - 1):
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
layers.append(resnet_block)
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se, se_block=se_block)
layers.append(resnet_block)

else:
for _ in range(1, layer_num):
resnet_block = block(out_channel, out_channel, stride=1, use_se=use_se)
layers.append(resnet_block)
return nn.SequentialCell(layers)

def construct(self, x):
x = self.conv1(x)
if self.use_se:
x = self.conv1_0(x)
x = self.bn1_0(x)
x = self.relu(x)
x = self.conv1_1(x)
x = self.bn1_1(x)
x = self.relu(x)
x = self.conv1_2(x)
else:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
c1, _ = self.maxpool(x)
c1 = self.maxpool(x)

c2 = self.layer1(c1)
c3 = self.layer2(c2)
@@ -335,7 +331,7 @@ class ResNet(nn.Cell):
return out


def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278):
def resnet50(class_num=10):
"""
Get ResNet50 neural network.

@@ -348,12 +344,50 @@ def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278):
Examples:
>>> net = resnet50(10)
"""
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)

def se_resnet50(class_num=1001):
"""
Get SE-ResNet50 neural network.

Args:
class_num (int): Class number.

Returns:
Cell, cell instance of SE-ResNet50 neural network.

Examples:
>>> net = se-resnet50(1001)
"""
return ResNet(ResidualBlock,
[3, 4, 6, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num,
damping,
loss_scale,
frequency)
use_se=True)

def resnet101(class_num=1001):
"""
Get ResNet101 neural network.

Args:
class_num (int): Class number.

Returns:
Cell, cell instance of ResNet101 neural network.

Examples:
>>> net = resnet101(1001)
"""
return ResNet(ResidualBlock,
[3, 4, 23, 3],
[64, 256, 512, 1024],
[256, 512, 1024, 2048],
[1, 2, 2, 2],
class_num)

+ 0
- 202
tests/st/networks/models/resnet50/src_thor/thor.py View File

@@ -1,202 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""momentum"""
import mindspore.common.dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.parameter import ParameterTuple
from mindspore.common.tensor import Tensor
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.ops import functional as F, composite as C, operations as P
from mindspore.parallel._utils import _get_device_num, _get_gradients_mean

from .grad_reducer_thor import DistributedGradReducerThor

momentum_opt = C.MultitypeFuncGraph("momentum_opt")


@momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor")
def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment):
"""Apply momentum optimizer to the weight parameter using Tensor."""
success = True
success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum))
return success


op_add = P.AddN()
apply_decay = C.MultitypeFuncGraph("apply_decay")


@apply_decay.register("Number", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
"""Get grad with weight_decay."""
if if_apply:
return op_add((weight * weight_decay, gradient))
return gradient


class THOR(Optimizer):
"""THOR"""

def __init__(self, params, learning_rate, momentum, matrix_A, matrix_G, A_inv_max, G_inv_max, weight_decay=0.0,
loss_scale=1.0,
decay_filter=lambda x: x.name not in []):
super(THOR, self).__init__(learning_rate, params, weight_decay, loss_scale)
if isinstance(momentum, float) and momentum < 0.0:
raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum))
self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum")
self.params = self.parameters
self.moments = self.params.clone(prefix="moments", init='zeros')
self.hyper_map = C.HyperMap()
self.opt = P.ApplyMomentum()
self.matrix_A = ParameterTuple(matrix_A)
self.matrix_G = ParameterTuple(matrix_G)
self.A_inv_max = ParameterTuple(A_inv_max)
self.G_inv_max = ParameterTuple(G_inv_max)
self.cube_matmul_left = P.CusMatMulCubeFraczLeftCast()
self.cube_matmul_left_fc = P.CusMatMulCubeDenseLeft()
self.cube_matmul_right_fc = P.CusMatMulCubeDenseRight()
self.cube_matmul_right_mul = P.CusMatMulCubeFraczRightMul()
self.transpose = P.Transpose()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.mul = P.Mul()
self.weight_idx = []
for i in range(len(self.params)):
if "conv" in self.params[i].name or "end_point" in self.params[i].name:
self.weight_idx.append(i)
self.weight_idx.append(len(self.params))
self.feature_map = [1.0 / 12544, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136, 1.0 / 3136,
1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784, 1.0 / 784,
1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 196, 1.0 / 196, 1.0 / 196,
1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49, 1.0 / 49,
1.0]
mean = _get_gradients_mean()
degree = _get_device_num()
parameter_length = len(self.feature_map)
self.grad_reducer_Amax = DistributedGradReducerThor(parameter_length, ((27,), 2), mean, degree)
self.grad_reducer_Gmax = DistributedGradReducerThor(parameter_length, ((27,), 4), mean, degree)
self.grad_reducer_A = DistributedGradReducerThor(parameter_length, ((27,), 6), mean, degree)
self.grad_reducer_G = DistributedGradReducerThor(parameter_length, ((27,), 8), mean, degree)
self.matrix_A_inv = ()
self.matrix_G_inv = ()
self.matrix_max_inv = ()

for i in range(54):
self.matrix_max_inv = self.matrix_max_inv + (
Parameter(initializer(1, [1], mstype.float32), name="matrix_max" + str(i), requires_grad=False),)
self.log = P.Log()
self.exp = P.Exp()
self.sqrt = P.Sqrt()
self.matrix_max_inv = ParameterTuple(self.matrix_max_inv)
self.assign = P.Assign()
self.cast = P.Cast()
self.thor = True
self.weight_decay = weight_decay * loss_scale
self.decay_flags = tuple(decay_filter(x) for x in self.parameters)

def construct(self, gradients):
params = self.params
moments = self.moments
if self.thor:
matrix_A_allreduce = ()
matrix_G_allreduce = ()
matrix_A_max_allreduce = ()
matrix_G_max_allreduce = ()
for i in range(54):
g = gradients[i * 3]
matrix_A = self.matrix_A[i]
matrix_G = self.matrix_G[i]
A_max = self.A_inv_max[i]
G_max = self.G_inv_max[i]
matrix_A = F.depend(matrix_A, g)
matrix_G = F.depend(matrix_G, g)
A_max = F.depend(A_max, g)
G_max = F.depend(G_max, g)
matrix_A_allreduce = matrix_A_allreduce + (matrix_A,)
matrix_G_allreduce = matrix_G_allreduce + (matrix_G,)
matrix_A_max_allreduce = matrix_A_max_allreduce + (A_max,)
matrix_G_max_allreduce = matrix_G_max_allreduce + (G_max,)
matrix_A_allreduce = self.grad_reducer_A(matrix_A_allreduce)
matrix_G_allreduce = self.grad_reducer_G(matrix_G_allreduce)
matrix_A_max_allreduce = self.grad_reducer_Amax(matrix_A_max_allreduce)
matrix_G_max_allreduce = self.grad_reducer_Gmax(matrix_G_max_allreduce)
new_grads = ()
for i in range(54):
g = gradients[i * 3]
temp_a = matrix_A_allreduce[i]
temp_g = matrix_G_allreduce[i]
temp_a = self.cast(temp_a, mstype.float32)
temp_g = self.cast(temp_g, mstype.float32)
matrix_A_inv_max = self.log(matrix_A_max_allreduce[i])
matrix_A_inv_max = self.mul(matrix_A_inv_max, -1)
matrix_A_inv_max = self.exp(matrix_A_inv_max)
temp_a = self.mul(temp_a, matrix_A_inv_max)
matrix_G_inv_max = self.log(matrix_G_max_allreduce[i])
matrix_G_inv_max = self.mul(matrix_G_inv_max, -1)
matrix_G_inv_max = self.exp(matrix_G_inv_max)
temp_g = self.mul(temp_g, matrix_G_inv_max)
temp_max = self.mul(matrix_A_max_allreduce[i], matrix_G_max_allreduce[i])
temp_max = self.mul(temp_max, self.feature_map[i])
temp_a = self.cast(temp_a, mstype.float16)
temp_g = self.cast(temp_g, mstype.float16)
if i == 53:
g = self.cube_matmul_left_fc(temp_g, g)
g = self.cube_matmul_right_fc(g, temp_a, temp_max)
else:
g = self.cube_matmul_left(temp_g, g)
g = self.cube_matmul_right_mul(g, temp_a, temp_max)
fake_A = self.assign(self.matrix_A[i], temp_a)
fake_G = self.assign(self.matrix_G[i], temp_g)
fake_max = self.assign(self.matrix_max_inv[i], temp_max)
g = F.depend(g, fake_A)
g = F.depend(g, fake_G)
g = F.depend(g, fake_max)
if i == 53:
new_grads = new_grads + (g,)
else:
new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
gradients = new_grads
else:
new_grads = ()
for i in range(54):
g = gradients[i * 3]
matrix_A = self.matrix_A[i]
matrix_G = self.matrix_G[i]
matrix_max = self.matrix_max_inv[i]
matrix_A = F.depend(matrix_A, g)
matrix_G = F.depend(matrix_G, g)
matrix_max = F.depend(matrix_max, g)
if i == 53:
g = self.cube_matmul_left_fc(matrix_G, g)
g = self.cube_matmul_right_fc(g, matrix_A, matrix_max)
new_grads = new_grads + (g,)
else:
g = self.cube_matmul_left(matrix_G, g)
g = self.cube_matmul_right_mul(g, matrix_A, matrix_max)
new_grads = new_grads + (g, gradients[i * 3 + 1], gradients[i * 3 + 2])
gradients = new_grads

if self.weight_decay > 0:
gradients = self.hyper_map(F.partial(apply_decay, self.weight_decay), self.decay_flags,
params, gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
success = self.hyper_map(F.partial(momentum_opt, self.opt, lr, self.momentum), gradients, params, moments)
return success

+ 0
- 479
tests/st/networks/models/resnet50/src_thor/thor_layer.py View File

@@ -1,479 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""thor_layer"""
import numpy as np
import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator, twice
from mindspore._extends import cell_attr_register
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.common.tensor import Tensor
from mindspore.nn.cell import Cell
from mindspore.nn.layer.activation import get_activation
from mindspore.ops import operations as P

C0 = 16


def caculate_device_shape(matrix_dim, channel, is_A):
ll = (0)
if is_A:
if channel // C0 == 0:
matrix_dim = (matrix_dim / channel) * C0
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
else:
ll = (int(matrix_dim // C0), int(matrix_dim // C0), C0, C0), int(matrix_dim)
return ll


class _Conv(Cell):
r"""Applies a N-D convolution over an input signal composed of several input
planes.
"""

def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
data_format,
has_bias,
weight_init,
bias_init,
):
super(_Conv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.pad_mode = pad_mode
self.padding = padding
self.dilation = dilation
self.group = group
self.data_format = data_format
self.has_bias = has_bias
if not (isinstance(in_channels, int) and in_channels > 0):
raise ValueError('Attr \'in_channels\' of \'Conv2D\' Op passed '
+ str(in_channels) + ', should be a int and greater than 0.')
if (not isinstance(kernel_size, tuple)) or len(kernel_size) != 2 or \
(not isinstance(kernel_size[0], int)) or (not isinstance(kernel_size[1], int)) or \
kernel_size[0] < 1 or kernel_size[1] < 1:
raise ValueError('Attr \'kernel_size\' of \'Conv2D\' Op passed '
+ str(self.kernel_size) + ', should be a int or tuple and equal to or greater than 1.')
if in_channels % group != 0:
raise ValueError('Attr \'in_channels\' of \'Conv2D\' Op must be divisible by '
'attr \'group\' of \'Conv2D\' Op.')
if out_channels % group != 0:
raise ValueError('Attr \'out_channels\' of \'Conv2D\' Op must be divisible by '
'attr \'group\' of \'Conv2D\' Op.')

self.weight = Parameter(initializer(
weight_init, [out_channels, in_channels // group, *kernel_size]), name='weight')

if Validator.check_bool(has_bias):
self.bias = Parameter(_initializer(
bias_init, [out_channels]), name='bias')
else:
if bias_init != 'zeros':
logger.warning("Value of 'has_bias' is False, value of 'bias_init' will be ignored.")
self.bias = None

def construct(self, *inputs):
raise NotImplementedError


class Conv2d_Thor(_Conv):
"""Conv2d_Thor"""

def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
data_format='NCHW',
has_bias=False,
weight_init='normal',
damping=0.03,
loss_scale=1,
frequency=278,
bias_init='zeros'):
self.thor = True
ksizes = (1, kernel_size, kernel_size, 1)
self.hw = kernel_size * kernel_size
strides = (1, stride, stride, 1)
kernel_size = twice(kernel_size)
super(Conv2d_Thor, self).__init__(
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
data_format,
has_bias,
weight_init,
bias_init,
)
self.conv2d = P.Conv2D(out_channel=self.out_channels,
kernel_size=self.kernel_size,
mode=1,
pad_mode=self.pad_mode,
pad=self.padding,
stride=self.stride,
dilation=self.dilation,
group=self.group
)

self.img2col = P.CusImg2Col(ksizes=ksizes, strides=strides)
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
self.matrix_combine = P.CusMatrixCombine()
self.cholesky = P.CusCholeskyTrsm()
self.transpose02314 = P.CusTranspose02314()
self.matrix_A_dim = self.in_channels * self.kernel_size[0] * self.kernel_size[1]
self.matrix_G_dim = self.out_channels
self.matrix_A_device_shape, self.matrix_A_device_dim = caculate_device_shape(self.matrix_A_dim,
self.in_channels, True)
self.matrix_G_device_shape, self.matrix_G_device_dim = caculate_device_shape(self.matrix_G_dim,
self.in_channels, False)
self.matrix_A_device_temp_shape = (
self.matrix_A_device_shape[0], self.matrix_A_device_shape[2], self.matrix_A_device_shape[1],
self.matrix_A_device_shape[3])
self.matrix_G_device_temp_shape = (
self.matrix_G_device_shape[0], self.matrix_G_device_shape[2], self.matrix_G_device_shape[1],
self.matrix_G_device_shape[3])
self.matrix_A_inv = Parameter(
Tensor(np.reshape(np.identity(self.matrix_A_device_dim).astype(np.float16), self.matrix_A_device_shape)),
name='matrix_A_inv', requires_grad=False)
self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False)
self.matrix_G_inv = Parameter(
Tensor(np.reshape(np.identity(self.matrix_G_device_dim).astype(np.float16), self.matrix_G_device_shape)),
name="matrix_G_inv", requires_grad=False)

self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False)
self.fake_G = Tensor(
np.reshape(np.identity(self.matrix_G_device_dim).astype(np.float16), self.matrix_G_device_shape))

self.shape = P.Shape()
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
self.mul = P.Mul()
self.cast = P.Cast()
self.damping = Tensor(damping)
self.vector_matmul = P.CusBatchMatMul()
self.diag_block_dim = 128
self.channels_slice_flag = False
if self.in_channels % C0 != 0:
self.channels_slice_flag = True

self.padA_flag = False
if (self.matrix_A_dim // self.diag_block_dim) * self.diag_block_dim != self.matrix_A_dim \
and self.matrix_A_dim > self.diag_block_dim:
self.padA_flag = True
pad_dim = self.diag_block_dim - self.matrix_A_dim % self.diag_block_dim
self.padA = P.Pad(((0, pad_dim), (0, pad_dim)))
self.device_shape_pad_flag = False
if self.matrix_A_dim != self.matrix_A_device_dim:
self.device_shape_pad_flag = True
self.device_shape_pad = P.Pad(((0, 0), (0, C0 - self.in_channels), (0, 0), (0, C0 - self.in_channels)))
self.slice = P.Slice()
self.gather = P.Gather()
self.freq = Tensor(frequency, mstype.int32)
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
self.axis = 0

dampingA_dim = self.matrix_A_dim
if (self.matrix_A_dim % self.diag_block_dim) != 0 and self.matrix_A_dim > self.diag_block_dim:
dampingA_dim = (self.matrix_A_dim // self.diag_block_dim + 1) * self.diag_block_dim
dampingG_dim = self.matrix_G_dim
if (self.matrix_G_dim % self.diag_block_dim) != 0 and self.matrix_G_dim > self.diag_block_dim:
dampingG_dim = (self.matrix_G_dim // self.diag_block_dim + 1) * self.diag_block_dim

self.dampingA = Tensor(np.identity(dampingA_dim), mstype.float32)
self.dampingG = Tensor(np.identity(dampingG_dim), mstype.float32)
self.fused_abs_max1 = P.CusFusedAbsMax1([self.matrix_A_dim, self.matrix_A_dim])
self.fused_abs_max2 = P.CusFusedAbsMax1()
self.log = P.Log()
self.exp = P.Exp()
self.sqrt = P.Sqrt()
self.getG = P.InsertGradientOf(self.save_gradient)

def save_gradient(self, dout):
"""save_gradient"""
out = dout
dout = self.mul(dout, self.loss_scale)
dout = self.mul(dout, 32.0)
dout = self.transpose02314(dout)
dout_shape = self.shape(dout)
normalizer = dout_shape[0]

matrix_G = self.cube_matmul(dout, dout)
normalizer = self.cast(normalizer, ms.float32)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
damping_step = self.gather(self.damping, self.cov_step, 0)
self.cov_step = self.cov_step + self.freq
damping_step = self.cast(damping_step, mstype.float32)
damping = self.mul(damping_step, 32.0 / normalizer)
damping = self.sqrt(damping)
dampingG = self.cast(self.dampingG, mstype.float32)
matrix_G = matrix_G + damping * dampingG

matrix_G_inv = self.cholesky(matrix_G)
matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
matrix_G_inv_max = self.fused_abs_max2(matrix_G_inv)
matrix_G_inv_max = self.fused_abs_max2(matrix_G_inv_max)
self.G_inv_max = matrix_G_inv_max
matrix_G_inv = self.matrix_combine(matrix_G_inv)
matrix_G_inv = self.reshape(matrix_G_inv, self.matrix_G_device_temp_shape)
matrix_G_inv = self.transpose(matrix_G_inv, (2, 0, 1, 3))
matrix_G = self.cast(matrix_G_inv, mstype.float16)
self.matrix_G_inv = matrix_G
return out

def construct(self, x):
if self.thor:
matrix_A = self.img2col(x)
matrix_A_shape = self.shape(matrix_A)
normalizer = matrix_A_shape[0]
matrix_A = self.cube_matmul(matrix_A, matrix_A)

if self.channels_slice_flag:
matrix_A = self.reshape(matrix_A, (self.hw, C0, self.hw, C0))
matrix_A = self.slice(matrix_A, (0, 0, 0, 0), (self.hw, self.in_channels, self.hw, self.in_channels))
matrix_A = self.reshape(matrix_A, (self.matrix_A_dim, self.matrix_A_dim))
normalizer = self.cast(normalizer, ms.float32)
matrix_A = self.mul(matrix_A, 1.0 / normalizer)
if self.padA_flag:
matrix_A = self.padA(matrix_A)
damping_step = self.gather(self.damping, self.cov_step, self.axis)
damping_step = self.cast(damping_step, mstype.float32)
damping = self.mul(damping_step, 32.0 / normalizer)
damping = self.sqrt(damping)
damping_A = self.cast(self.dampingA, mstype.float32)
matrix_A = matrix_A + damping * damping_A
matrix_A_inv = self.cholesky(matrix_A)
matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv)
matrix_A_inv_max = self.fused_abs_max1(matrix_A_inv)
matrix_A_inv_max = self.fused_abs_max2(matrix_A_inv_max)
self.A_inv_max = matrix_A_inv_max
matrix_A_inv = self.matrix_combine(matrix_A_inv)
matrix_A_inv = self.cast(matrix_A_inv, mstype.float16)
if self.padA_flag:
matrix_A_inv = self.slice(matrix_A_inv, (0, 0), (self.matrix_A_dim, self.matrix_A_dim))

if self.device_shape_pad_flag:
matrix_A_inv = self.reshape(matrix_A_inv, (self.hw, self.in_channels, self.hw, self.in_channels))
matrix_A_inv = self.device_shape_pad(matrix_A_inv)
matrix_A_inv = self.reshape(matrix_A_inv, self.matrix_A_device_temp_shape)
matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3))
self.matrix_A_inv = matrix_A_inv
self.matrix_G_inv = self.fake_G
out = self.conv2d(x, self.weight)
out = self.getG(out)
else:
out = self.conv2d(x, self.weight)

return out

def extra_repr(self):
"""extra_repr"""
s = 'input_channels={}, output_channels={}, kernel_size={},' \
'stride={}, pad_mode={}, padding={}, dilation={}, ' \
'group={}, data_format={}, has_bias={},' \
'weight_init={}, bias_init={}'.format(
self.in_channels,
self.out_channels,
self.kernel_size,
self.stride,
self.pad_mode,
self.padding,
self.dilation,
self.group,
self.data_format,
self.has_bias,
self.weight,
self.bias)

if self.has_bias:
s += ', bias={}'.format(self.bias)
return s


class Dense_Thor(Cell):
"""Dense_Thor"""

@cell_attr_register(attrs=['has_bias', 'activation'])
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
damping=0.03,
loss_scale=1,
frequency=278,
has_bias=True,
activation=None):
super(Dense_Thor, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
self.has_bias = Validator.check_bool(has_bias)
self.thor = True
if isinstance(weight_init, Tensor):
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("weight_init shape error")

self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")

if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
raise ValueError("bias_init shape error")

self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")

self.matmul = P.MatMul(transpose_b=True)
self.bias_add = P.BiasAdd()

self.activation = get_activation(activation)
self.activation_flag = self.activation is not None

self.matrix_A_inv = Parameter(Tensor(np.zeros([128, 128, 16, 16]).astype(np.float16)), name='matrix_A_inv',
requires_grad=False)
self.matrix_G_inv = Parameter(Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16)), name="matrix_G_inv",
requires_grad=False)
self.fake_G = Tensor(np.zeros([63, 63, 16, 16]).astype(np.float16))

self.matmul = P.MatMul(transpose_b=True)
self.cube_matmul = P.CusMatMulCube(transpose_a=True)
self.matrix_combine = P.CusMatrixCombine()
self.cholesky = P.CusCholeskyTrsm()
self.shape = P.Shape()
self.reshape = P.Reshape()
self.transpose = P.Transpose()
self.cov_step = Parameter(initializer(0, [1], mstype.int32), name="cov_step", requires_grad=False)
self.mul = P.Mul()
self.cast = P.Cast()
self.damping = Tensor(damping)
self.loss_scale = Tensor(1 / loss_scale, mstype.float16)
self.vector_matmul = P.CusBatchMatMul()
self.pad = P.Pad(((0, 24), (0, 24)))
self.pad1 = P.Pad(((0, 8), (0, 8)))
self.slice = P.Slice()
self.gather = P.Gather()
self.assignadd = P.AssignAdd()
self.freq = Tensor(frequency, mstype.int32)
self.axis = 0
self.A_inv_max = Parameter(initializer(0, [1], mstype.float32), name="A_inv_max", requires_grad=False)
self.G_inv_max = Parameter(initializer(0, [1], mstype.float32), name="G_inv_max", requires_grad=False)
self.fused_abs_max1 = P.CusFusedAbsMax1([1000, 1000])
self.fused_abs_max2 = P.CusFusedAbsMax1()
self.log = P.Log()
self.exp = P.Exp()
self.dampingA = Tensor(np.identity(2048), mstype.float32)
self.dampingG = Tensor(np.identity(1024), mstype.float32)
self.add = P.Add()
self.sqrt = P.Sqrt()
self.getG = P.InsertGradientOf(self.save_gradient)

def save_gradient(self, dout):
"""save_gradient"""
out = dout
dout = self.mul(dout, self.loss_scale)
dout = self.mul(dout, 32.0)
normalizer = 32
matrix_G = self.cube_matmul(dout, dout)
normalizer = self.cast(normalizer, ms.float32)
matrix_G = self.mul(matrix_G, 1.0 / normalizer)
matrix_G = self.pad(matrix_G)
damping_step = self.gather(self.damping, self.cov_step, 0)
damping_step = self.cast(damping_step, mstype.float32)
self.cov_step = self.cov_step + self.freq
damping = self.sqrt(damping_step)
dampingG = self.cast(self.dampingG, mstype.float32)
matrix_G = matrix_G + damping * dampingG
matrix_G_inv = self.cholesky(matrix_G)
matrix_G_inv = self.vector_matmul(matrix_G_inv, matrix_G_inv)
matrix_G_inv_max = self.fused_abs_max1(matrix_G_inv)
matrix_G_inv_max = self.fused_abs_max2(matrix_G_inv_max)
self.G_inv_max = matrix_G_inv_max
matrix_G_inv = self.matrix_combine(matrix_G_inv)
matrix_G_inv = self.slice(matrix_G_inv, (0, 0), (1000, 1000))
matrix_G_inv = self.pad1(matrix_G_inv)
matrix_G_inv_shape = self.shape(matrix_G_inv)
matrix_G_inv = self.reshape(matrix_G_inv, (matrix_G_inv_shape[0] / 16, 16, matrix_G_inv_shape[0] / 16, 16))
matrix_G_inv = self.transpose(matrix_G_inv, (2, 0, 1, 3))
matrix_G_inv = self.cast(matrix_G_inv, mstype.float16)
self.matrix_G_inv = matrix_G_inv
return out

def construct(self, x):
"""construct"""
if self.thor:
inputs = self.cube_matmul(x, x)
normalizer = 32
normalizer = self.cast(normalizer, ms.float32)
matrix_A = self.mul(inputs, 1.0 / normalizer)

damping_step = self.gather(self.damping, self.cov_step, self.axis)
damping_step = self.cast(damping_step, mstype.float32)
damping = self.sqrt(damping_step)
dampingA = self.cast(self.dampingA, mstype.float32)
matrix_A = matrix_A + damping * dampingA
matrix_A_inv = self.cholesky(matrix_A)
matrix_A_inv = self.vector_matmul(matrix_A_inv, matrix_A_inv)

matrix_A_inv_max = self.fused_abs_max2(matrix_A_inv)
matrix_A_inv_max = self.fused_abs_max2(matrix_A_inv_max)
self.A_inv_max = matrix_A_inv_max

matrix_A_inv = self.matrix_combine(matrix_A_inv)
matrix_A_inv_shape = self.shape(matrix_A_inv)
matrix_A_inv = self.reshape(matrix_A_inv, (matrix_A_inv_shape[0] / 16, 16, matrix_A_inv_shape[0] / 16, 16))
matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3))
matrix_A_inv = self.cast(matrix_A_inv, mstype.float16)
self.matrix_A_inv = matrix_A_inv
self.matrix_G_inv = self.fake_G
output = self.matmul(x, self.weight)
output = self.getG(output)
else:
output = self.matmul(x, self.weight)

if self.has_bias:
output = self.bias_add(output, self.bias)
if self.activation_flag:
return self.activation(output)
return output

def extend_repr(self):
"""extend_repr"""
s = 'in_channels={}, out_channels={}'.format(self.in_channels, self.out_channels)
if self.has_bias:
s += ', has_bias={}'.format(self.has_bias)
if self.activation_flag:
s += ', activation={}'.format(self.activation)

return str_info

+ 21
- 24
tests/st/networks/models/resnet50/test_resnet50_imagenet.py View File

@@ -21,7 +21,8 @@ from multiprocessing import Process, Queue
import pytest
import numpy as np

from mindspore import context, Tensor
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.communication.management import init
from mindspore.train.model import Model
from mindspore.context import ParallelMode
@@ -29,6 +30,7 @@ from mindspore.train.callback import Callback
from mindspore.train.loss_scale_manager import FixedLossScaleManager
import mindspore.nn as nn
import mindspore.dataset as ds
from mindspore.nn.optim import THOR

from tests.st.networks.models.resnet50.src.resnet import resnet50
from tests.st.networks.models.resnet50.src.dataset import create_dataset
@@ -39,7 +41,7 @@ from tests.st.networks.models.resnet50.src.CrossEntropySmooth import CrossEntrop
from tests.st.networks.models.resnet50.src_thor.config import config as thor_config
from tests.st.networks.models.resnet50.src_thor.model_thor import Model as THOR_Model
from tests.st.networks.models.resnet50.src_thor.resnet import resnet50 as resnet50_thor
from tests.st.networks.models.resnet50.src_thor.thor import THOR

MINDSPORE_HCCL_CONFIG_PATH = "/home/workspace/mindspore_config/hccl/rank_tabel_4p/rank_table_4p_1.json"
MINDSPORE_HCCL_CONFIG_PATH_2 = "/home/workspace/mindspore_config/hccl/rank_tabel_4p/rank_table_4p_2.json"
@@ -50,7 +52,8 @@ np.random.seed(1)
ds.config.set_seed(1)
os.environ['GLOG_v'] = str(2)

def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch):

def get_thor_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch, decay_epochs=100):
"""get_model_lr"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
@@ -58,9 +61,9 @@ def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch):
epoch = (i + 1) / steps_per_epoch
base = (1.0 - float(epoch) / total_epochs) ** decay
lr_local = lr_init * base
if epoch >= 39:
if epoch >= decay_epochs:
lr_local = lr_local * 0.5
if epoch >= 40:
if epoch >= decay_epochs + 1:
lr_local = lr_local * 0.5
lr_each_step.append(lr_local)
current_step = global_step
@@ -69,7 +72,7 @@ def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch):
return learning_rate


def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch):
def get_thor_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch):
"""get_model_damping"""
damping_each_step = []
total_steps = steps_per_epoch * total_epochs
@@ -77,7 +80,6 @@ def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps
epoch = (step + 1) / steps_per_epoch
damping_here = damping_init * (decay_rate ** (epoch / 10))
damping_each_step.append(damping_here)

current_step = global_step
damping_each_step = np.array(damping_each_step).astype(np.float32)
damping_now = damping_each_step[current_step:]
@@ -140,6 +142,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
init()

# network

net = resnet50(class_num=config.class_num)

# evaluation network
@@ -160,7 +163,7 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
eval_interval = config.eval_interval
dataset.__loop_size__ = step_size * eval_interval

# evalutation dataset
# evaluation dataset
eval_dataset = create_dataset(dataset_path=eval_path, do_train=False,
repeat_num=1, batch_size=config.eval_batch_size)

@@ -233,16 +236,11 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
os.environ['RANK_SIZE'] = str(device_num)
if enable_hccl:
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True, all_reduce_fusion_config=[107])
gradients_mean=True, all_reduce_fusion_config=[85, 160])
init()

# network
damping = get_model_damping(0, 0.03, 0.87, 50, 5004)
net = resnet50_thor(class_num=thor_config.class_num, damping=damping, loss_scale=thor_config.loss_scale,
frequency=thor_config.frequency)

# evaluation network
dist_eval_network = ClassifyCorrectCell(net)
net = resnet50_thor(thor_config.class_num)

if not thor_config.label_smooth:
thor_config.label_smooth_factor = 0.0
@@ -258,7 +256,7 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
step_size = dataset.get_dataset_size()
eval_interval = thor_config.eval_interval

# evalutation dataset
# evaluation dataset
eval_dataset = create_dataset(dataset_path=eval_path, do_train=False,
repeat_num=1, batch_size=thor_config.eval_batch_size)

@@ -266,16 +264,15 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
loss_scale = FixedLossScaleManager(thor_config.loss_scale, drop_overflow_update=False)

# learning rate
lr = Tensor(get_model_lr(0, 0.045, 6, 70, 5004))
lr = get_thor_lr(0, 0.05803, 4.04839, 53, 5004, decay_epochs=39)
damping = get_thor_damping(0, 0.02714, 0.50036, 70, 5004)
# optimizer
opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr, thor_config.momentum,
filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()),
filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
thor_config.weight_decay, thor_config.loss_scale)
split_indices = [26, 53]
opt = THOR(net, Tensor(lr), Tensor(damping), thor_config.momentum, thor_config.weight_decay, thor_config.loss_scale,
thor_config.batch_size, split_indices=split_indices)

# evaluation network
dist_eval_network = ClassifyCorrectCell(net)
# model
model = THOR_Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, amp_level="O2",
keep_batchnorm_fp32=False,


Loading…
Cancel
Save