# Copyright 2022 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """UT test example for conformer.""" import math import numpy as np import pytest import mindspore import mindspore.nn as nn import mindspore.common.initializer as Init from mindspore import Tensor, context, Parameter from mindspore.ops import operations as P from mindspore.common.initializer import initializer from mindspore.ops import functional as F from mindspore.common.initializer import TruncatedNormal, HeNormal from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell from mindspore.nn.loss.loss import LossBase mindspore.set_seed(0) np.random.seed(0) def flatten(input_tensor, start_dim): shape = input_tensor.shape new_shape = shape[:start_dim] dims = 1 for i in range(start_dim, len(shape)): dims = dims * shape[i] return input_tensor.reshape(new_shape+(dims,)) def one_hot_int(label, num_classes): num_elements = label.size one_hot_label = np.zeros((num_elements, num_classes), dtype=np.int32) for index in range(num_elements): one_hot_label[index][label[index]] = 1 return Tensor(one_hot_label, mindspore.float32) class CrossEntropySmooth(LossBase): """CrossEntropy""" def __init__(self, reduction='mean', is_auto_parallel=False): super(CrossEntropySmooth, self).__init__() self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction) if is_auto_parallel: self.ce.reduce_mean.add_prim_attr("cross_batch", True) def construct(self, logit, label): loss = None idx = 0 for o in logit: o = F.cast(o, mindspore.float32) loss = self.ce(o, label) / len(logit) if idx == 0 else loss + self.ce(o, label) / len(logit) idx = idx + 1 return loss class NetWithLossCell(nn.Cell): """Metwithlosscell""" def __init__(self, backbone, loss_fn): super(NetWithLossCell, self).__init__(auto_prefix=False) self._backbone = backbone self._loss_fn = loss_fn def construct(self, data, label): output = self._backbone(data) loss = self._loss_fn(output, label) return loss class DropPath(nn.Cell): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ def __init__(self, drop_prob=None, num_dimension=4, dp=1): super(DropPath, self).__init__() self.drop_prob = drop_prob strategy_feat = (dp,) + (1,)*(num_dimension-1) self.uniformreal = P.UniformReal().shard((strategy_feat,)) self.floor = P.Floor().shard((strategy_feat,)) self.div = P.Div().shard((strategy_feat, ())) self.mul = P.Mul().shard((strategy_feat, strategy_feat)) self.add = P.Add().shard(((), strategy_feat)) def drop_path(self, x, drop_prob=0., training=True): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument. """ if drop_prob == 0. or not training: return x keep_prob = 1 - drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = self.add(keep_prob, F.cast(self.uniformreal(shape), mindspore.float32)) random_tensor = self.floor(random_tensor) output = self.mul(self.div(x, keep_prob), random_tensor) return output # fp32 def construct(self, x): return self.drop_path(x, self.drop_prob, self.training) class Norm(nn.Cell): r""" A self-defined layer norm operation using reduce sum and reduce mean Args: normalized_shape (tuple): The shape of the input tensor dp (int): The data parallel way of the inputs, Default:1 eps (float): The epsilon value of the denominator. Default 1e-5. Inputs: - **x** (Tensor) - Tensor of shape :math:`(batch, seq\_length, hidden\_size)`. Outputs: Tensor of shape :math:`(batch, seq_length, hidden_size)`. """ def __init__(self, normalized_shape, axes=-1, num_dimension=3, affine=True, dp=1, eps=1e-5, is_gn=False, num_groups=1): super(Norm, self).__init__() gamma = initializer('ones', normalized_shape) beta = initializer('zeros', normalized_shape) if affine: self.gamma = Parameter(gamma, name="gamma", parallel_optimizer=False) self.beta = Parameter(beta, name="beta", parallel_optimizer=False) else: self.gamma = gamma self.beta = beta strategy = [dp if i == 0 else 1 for i in range(num_dimension)] strategy = tuple(strategy) if is_gn: strategy1 = [dp if i == 0 else 1 for i in range(num_dimension-1)] strategy1 = tuple(strategy1) else: strategy1 = strategy self.mean = P.ReduceMean(keep_dims=True).shard((strategy1,)) self.square = P.Square().shard((strategy1,)) self.sqrt = P.Sqrt().shard((strategy1,)) self.sub1 = P.Sub().shard((strategy1, strategy1)) self.add = P.TensorAdd().shard((strategy1, ())) self.eps = eps self.real_div = P.RealDiv().shard((strategy1, strategy1)) self.mul = P.Mul().shard((strategy, (1, 1, 1))) self.add2 = P.TensorAdd().shard((strategy, (1, 1, 1))) self.axes = axes self.is_gn = is_gn self.num_groups = num_groups # layer norm (1,1,-1) (-1,1,1) if num_dimension == 3: self.view_shape = (1, 1, -1) else: self.view_shape = (-1, 1, 1) def construct(self, x): r""" x : batch x seq_length x hidden_size """ origin_shape = x.shape if self.is_gn: x = x.view(origin_shape[0], self.num_groups, -1) mean = self.mean(x, self.axes) diff = self.sub1(x, mean) variance = self.mean(self.square(diff), self.axes) variance_eps = self.sqrt(self.add(variance, self.eps)) output = self.real_div(diff, variance_eps) if self.is_gn: output = output.view(origin_shape) output = self.add2(self.mul(output, self.gamma.view(self.view_shape)), self.beta.view(self.view_shape)) return output class Mlp(nn.Cell): r""" MPL block """ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., dp=1, mp=1): super(Mlp, self).__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Dense(in_features, hidden_features, weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16) self.fc1.matmul.shard(((dp, 1), (mp, 1))) self.fc1.bias_add.shard(((dp, mp), (mp,))) self.act = act_layer() self.act.gelu.shard(((dp, mp),)) self.fc2 = nn.Dense(hidden_features, out_features, weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16) self.fc2.matmul.shard(((dp, mp), (1, mp))) self.fc2.bias_add.shard(((dp, 1), (1,))) self.drop = nn.Dropout(1.0-drop) self.drop.dropout.shard(((dp, 1),)) self.drop2 = nn.Dropout(1.0-drop) self.drop2.dropout.shard(((dp, mp),)) def construct(self, x): r""" x : fp32 """ origin_shape = x.shape x = x.view(-1, origin_shape[-1]) x = self.fc1(F.cast(x, mindspore.float16)) x = self.act(F.cast(x, mindspore.float32)) x = self.drop2(x) x = self.fc2(F.cast(x, mindspore.float16)) x = self.drop(F.cast(x, mindspore.float32)) x = x.view(origin_shape[:-1]+(-1,)) return x class Attention(nn.Cell): """Multi-head Attention""" def __init__(self, dim, hidden_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., dp=1, mp=1): super(Attention, self).__init__() hidden_dim = hidden_dim or dim self.hidden_dim = hidden_dim self.num_heads = num_heads head_dim = hidden_dim // num_heads self.head_dim = head_dim self.scale = head_dim ** -0.5 self.qk_scale = qk_scale self.mul = P.Mul().shard(((dp, mp, 1, 1), ())) self.q = nn.Dense(dim, hidden_dim, has_bias=qkv_bias, weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16) self.q.matmul.shard(((dp, 1), (mp, 1))) if qkv_bias: self.q.bias_add.shard(((dp, mp), (mp,))) self.k = nn.Dense(dim, hidden_dim, has_bias=qkv_bias, weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16) self.k.matmul.shard(((dp, 1), (mp, 1))) if qkv_bias: self.k.bias_add.shard(((dp, mp), (mp,))) self.v = nn.Dense(dim, hidden_dim, has_bias=qkv_bias, weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16) self.v.matmul.shard(((dp, 1), (mp, 1))) if qkv_bias: self.v.bias_add.shard(((dp, mp), (mp,))) self.softmax = nn.Softmax(axis=-1) self.softmax.softmax.shard(((dp, mp, 1, 1),)) self.batmatmul_trans_b = P.BatchMatMul().shard(((dp, mp, 1, 1), (dp, mp, 1, 1))) self.attn_drop = nn.Dropout(1. - attn_drop) self.attn_drop.dropout.shard(((dp, mp, 1, 1),)) self.proj = nn.Dense(hidden_dim, dim, weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16) self.proj.matmul.shard(((dp, mp), (1, mp))) self.proj.bias_add.shard(((dp, 1), (1,))) self.proj_drop = nn.Dropout(1. - proj_drop) self.proj_drop.dropout.shard(((dp, 1),)) self.transpose = P.Transpose().shard(((dp, 1, mp, 1),)) self.transpose2 = P.Transpose().shard(((dp, 1, 1, 1),)) self.reshape = P.Reshape() def construct(self, x): """Multi-head Attention""" b_size, n_channel, _ = x.shape # fp32 x = F.cast(x, mindspore.float16) x = x.view(b_size*n_channel, -1) q = self.q(x) k = self.k(x) v = self.v(x) q = self.transpose( F.reshape( q, (-1, n_channel, self.num_heads, self.head_dim)), (0, 2, 1, 3)) k = self.transpose( F.reshape( k, (-1, n_channel, self.num_heads, self.head_dim)), (0, 2, 3, 1)) v = self.transpose( F.reshape( v, (-1, n_channel, self.num_heads, self.head_dim)), (0, 2, 1, 3)) attn = self.softmax(F.cast(self.batmatmul_trans_b(self.mul(q, self.scale), k), mindspore.float32)) attn = self.attn_drop(attn) x = self.reshape(self.transpose2(self.batmatmul_trans_b(F.cast(attn, mindspore.float16), v), (0, 2, 1, 3)), (b_size*n_channel, -1)) x = self.proj(x) x = self.proj_drop(x) # fp16 return x.view(b_size, n_channel, -1) class Block(nn.Cell): """Block.""" def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, dp=1, mp=1): super(Block, self).__init__() self.norm1 = norm_layer([dim], epsilon=1e-6) self.norm1.layer_norm.shard(((dp, 1, 1), (1,), (1,))) self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, dp=dp, mp=mp) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path, num_dimension=3, dp=dp) if drop_path > 0. else P.Identity() self.norm2 = norm_layer([dim], epsilon=1e-6) self.norm2.layer_norm.shard(((dp, 1, 1), (1,), (1,))) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, dp=dp, mp=mp) self.add = P.Add().shard(((dp, 1, 1), (dp, 1, 1))) def construct(self, x): # x fp32 x = self.add(x, self.drop_path(self.attn(self.norm1(x)))) # output x fp32 x = self.add(x, self.drop_path(self.mlp(self.norm2(x)))) # output x fp32 return x class ConvBlock(nn.Cell): """ConvBlock""" def __init__(self, inplanes, outplanes, stride=1, res_conv=False, act_layer=nn.ReLU, groups=1, norm_layer=nn.BatchNorm2d, drop_block=None, drop_path=0., return_x_2=False, weighted_fusion=False, dp=1): super(ConvBlock, self).__init__() self.init_network(inplanes, outplanes, norm_layer, act_layer, stride, groups, dp) self.add = P.Add().shard(((dp, 1, 1, 1), (dp, 1, 1, 1))) self.mul = P.Mul().shard(((1,), (dp, 1, 1, 1))) if res_conv: self.residual_conv = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, padding=0, has_bias=False, pad_mode="pad", weight_init=HeNormal(mode='fan_out', nonlinearity='relu')).to_float(mindspore.float16) self.residual_conv.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1))) self.residual_conv.bias_add.shard(((dp, 1, 1, 1), (1,))) self.residual_bn = norm_layer(outplanes, eps=1e-6) self.residual_bn.bn_train.shard(((dp, 1, 1, 1), (1,), (1,), (1,), (1,))) self.res_conv = res_conv self.drop_block = drop_block self.drop_path = DropPath(drop_path, dp=dp) self.return_x_2 = return_x_2 self.weighted_fusion = weighted_fusion if weighted_fusion: self.add1 = P.Add().shard(((), (1,))) self.div = P.Div().shard(((), (1,))) self.exp = P.Exp().shard(((1,),)) self.neg = P.Neg().shard(((1,),)) self.c = Parameter(Tensor(np.zeros((1,)), mindspore.float16), requires_grad=True) def init_network(self, inplanes, outplanes, norm_layer, act_layer, stride, groups, dp): expansion = 4 med_planes = outplanes // expansion self.conv1 = nn.Conv2d(inplanes, med_planes, kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="pad", weight_init=HeNormal(mode='fan_out', nonlinearity='relu')).to_float(mindspore.float16) self.conv1.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1))) self.conv1.bias_add.shard(((dp, 1, 1, 1), (1,))) self.bn1 = norm_layer(med_planes, eps=1e-6) self.bn1.bn_train.shard(((dp, 1, 1, 1), (1,), (1,), (1,), (1,))) self.act1 = act_layer() self.act1.relu.shard(((dp, 1, 1, 1),)) self.conv2 = nn.Conv2d(med_planes, med_planes, kernel_size=3, stride=stride, group=groups, padding=1, has_bias=False, pad_mode="pad", weight_init=HeNormal(mode='fan_out', nonlinearity='relu')).to_float(mindspore.float16) self.conv2.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1))) self.conv2.bias_add.shard(((dp, 1, 1, 1), (1,))) self.bn2 = norm_layer(med_planes, eps=1e-6) self.bn2.bn_train.shard(((dp, 1, 1, 1), (1,), (1,), (1,), (1,))) self.act2 = act_layer() self.act2.relu.shard(((dp, 1, 1, 1),)) self.conv3 = nn.Conv2d(med_planes, outplanes, kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="pad", weight_init=HeNormal(mode='fan_out', nonlinearity='relu')).to_float(mindspore.float16) self.conv3.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1))) self.conv3.bias_add.shard(((dp, 1, 1, 1), (1,))) self.bn3 = norm_layer(outplanes, eps=1e-6) self.bn3.bn_train.shard(((dp, 1, 1, 1), (1,), (1,), (1,), (1,))) self.act3 = act_layer() self.act3.relu.shard(((dp, 1, 1, 1),)) def construct(self, x, x_t=None): """ConvBlock construct""" residual = x x = self.conv1(x) # fp16 x = self.bn1(F.cast(x, mindspore.float32)) x = F.cast(x, mindspore.float16) if self.drop_block is not None: x = self.drop_block(x) x = self.act1(x) # fp16 if x_t is None: x = self.conv2(x) else: if self.weighted_fusion: c = self.div(1.0, self.add1(1.0, self.exp(self.neg(self.c)))) x = self.conv2(self.add(self.mul(c, x), self.mul(1.0-c, F.cast(x_t, mindspore.float16)))) else: x = self.conv2(self.add(x, F.cast(x_t, mindspore.float16))) x = self.bn2(F.cast(x, mindspore.float32)) x = F.cast(x, mindspore.float16) if self.drop_block is not None: x = self.drop_block(x) x2 = self.act2(x) x = self.conv3(x2) x = self.bn3(F.cast(x, mindspore.float32)) x = F.cast(x, mindspore.float16) if self.drop_block is not None: x = self.drop_block(x) if self.drop_path is not None: x = self.drop_path(x) if self.res_conv: residual = self.residual_conv(residual) residual = self.residual_bn(F.cast(residual, mindspore.float32)) residual = F.cast(residual, mindspore.float16) x = self.add(x, residual) x = self.act3(x) if self.return_x_2: return x, x2 return x class FCUDown(nn.Cell): """ CNN feature maps -> Transformer patch embeddings """ def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU, norm_layer=nn.LayerNorm, cls_token=True, dp=1): super(FCUDown, self).__init__() self.dw_stride = dw_stride self.cls_token = cls_token self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0, has_bias=True, pad_mode="pad", weight_init=HeNormal(mode='fan_out', nonlinearity='relu')).to_float(mindspore.float16) self.conv_project.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1))) self.conv_project.bias_add.shard(((dp, 1, 1, 1), (1,))) self.sample_pooling = nn.AvgPool2d(kernel_size=dw_stride, stride=dw_stride) self.sample_pooling.avg_pool.shard(((dp, 1, 1, 1),)) self.ln = norm_layer([outplanes], epsilon=1e-6) self.ln.layer_norm.shard(((dp, 1, 1), (1,), (1,))) self.act = act_layer() self.act.gelu.shard(((dp, 1, 1),)) self.concat = P.Concat(axis=1).shard(((dp, 1, 1), (dp, 1, 1))) self.transpose = P.Transpose().shard(((dp, 1, 1),)) self.slice = P.StridedSlice().shard(((dp, 1, 1),)) def construct(self, x, x_t): """FCUDown construct""" # x fp16, x_t fp32 x = self.conv_project(x) # [N, C, H, W] tmp = self.sample_pooling(x) tmp1 = flatten(tmp, 2) x = self.transpose(tmp1, (0, 2, 1)) x = self.ln(F.cast(x, mindspore.float32)) x = self.act(x) if self.cls_token: b_size, _, height = F.shape(x_t) tmp2 = self.slice(x_t, (0, 0, 0), (b_size, 1, height), (1, 1, 1)) x = self.concat([tmp2, x]) return x class FCUUp(nn.Cell): """ Transformer patch embeddings -> CNN feature maps """ def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, cls_token=True, seq_length=196, dp=1): super(FCUUp, self).__init__() self.up_stride = up_stride self.conv_project = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0, has_bias=True, pad_mode="pad", weight_init=HeNormal(mode='fan_out', nonlinearity='relu')).to_float(mindspore.float16) self.conv_project.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1))) self.conv_project.bias_add.shard(((dp, 1, 1, 1), (1,))) self.ln = Norm(inplanes, axes=-1, affine=False, dp=dp, eps=1e-6) self.bn = norm_layer(outplanes, eps=1e-6) self.bn.bn_train.shard(((dp, 1, 1, 1), (1,), (1,), (1,), (1,))) self.act = act_layer() self.act.relu.shard(((dp, 1, 1, 1),)) self.cls_token = cls_token height = weight = int(math.sqrt(seq_length)) self.resize_neighbor = P.ResizeNearestNeighbor(size=(height * self.up_stride, weight * self.up_stride)).shard(((dp, 1, 1, 1),)) self.reshape = P.Reshape() self.transpose = P.Transpose().shard(((dp, 1, 1),)) self.slice = P.StridedSlice().shard(((dp, 1, 1),)) def construct(self, x, height, weight): """FCUUp construct""" # x fp32 b_size, t_num, channel = F.shape(x) x = self.ln(x) if self.cls_token: x_r = self.reshape(self.transpose(\ self.slice(x, (0, 1, 0), (b_size, t_num, channel),\ (1, 1, 1)), (0, 2, 1)), (b_size, channel, height, weight)) else: x_r = self.reshape(self.transpose(x, (0, 2, 1)), (b_size, channel, height, weight)) # x_r fp32 x_r_fp32 = F.cast(self.conv_project(F.cast(x_r, mindspore.float16)), mindspore.float32) x_r_fp16 = F.cast(self.bn(x_r_fp32), mindspore.float16) x_r = self.act(x_r_fp16) return self.resize_neighbor(x_r) class ConvTransBlock(nn.Cell): """ Basic module for ConvTransformer, keep feature maps for CNN block and patch embeddings for transformer encoder block """ def __init__(self, inplanes, outplanes, res_conv, stride, dw_stride, embed_dim, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., last_fusion=False, groups=1, cls_token=True, weighted_fusion=False, dp=1, mp=1, seq_length=196): super(ConvTransBlock, self).__init__() expansion = 4 self.cnn_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=stride, groups=groups, drop_path=drop_path_rate, return_x_2=True, dp=dp) if last_fusion: self.fusion_block = ConvBlock(inplanes=outplanes, outplanes=outplanes, stride=2, res_conv=True, groups=groups, drop_path=drop_path_rate, weighted_fusion=weighted_fusion, dp=dp) else: self.fusion_block = ConvBlock(inplanes=outplanes, outplanes=outplanes, groups=groups, drop_path=drop_path_rate, weighted_fusion=weighted_fusion, dp=dp) self.squeeze_block = FCUDown(inplanes=outplanes // expansion, outplanes=embed_dim, dw_stride=dw_stride, cls_token=cls_token, dp=dp) self.expand_block = FCUUp(inplanes=embed_dim, outplanes=outplanes // expansion, up_stride=dw_stride, cls_token=cls_token, dp=dp, seq_length=seq_length) self.trans_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate, dp=dp, mp=mp) self.dw_stride = dw_stride self.embed_dim = embed_dim self.last_fusion = last_fusion self.weighted_fusion = weighted_fusion if weighted_fusion: self.exp = P.Exp().shard(((1,),)) self.c = Parameter(Tensor(np.zeros((1,)), mindspore.float16), requires_grad=True) self.add = P.Add().shard(((dp, 1, 1), (dp, 1, 1))) self.add1 = P.Add().shard(((), (1,))) self.div = P.Div().shard(((), (1,))) self.mul = P.Mul().shard(((1,), (dp, 1, 1))) self.sub = P.Sub().shard(((), (1,))) self.neg = P.Neg().shard(((1,),)) def construct(self, x, x_t): """ConvTransBlock construct""" # x fp16, x_t fp32 x, x2 = self.cnn_block(x) # both fp16 _, _, height, weight = x2.shape x_st = self.squeeze_block(x2, x_t) # x_st fp32 if self.weighted_fusion: c = self.div(1.0, self.add1(1.0, self.exp(self.neg(self.c)))) x_t = self.trans_block(self.add(self.mul(c, x_st), self.mul(self.sub(1.0, c), x_t))) else: x_t = self.trans_block(self.add(x_st, x_t)) # x_t fp32 x_t_r = self.expand_block(x_t, height // self.dw_stride, weight // self.dw_stride) # x_t_r fp16 x = self.fusion_block(x, x_t_r) return x, x_t class ConformerOverflow(nn.Cell): """Conformeroverflow""" def __init__(self, patch_size=16, in_chans=3, num_classes=1000, base_channel=64, channel_ratio=4, embed_dim=768, stage_point=None, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., cls_token=True, batch_size=8, weighted_fusion=False, dp=1, mp=1, seq_length=196): # Transformer super(ConformerOverflow, self).__init__() self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim depth = stage_point[-1] self.cls_token_flag = cls_token if self.cls_token_flag: self.cls_token = mindspore.Parameter(initializer('zeros', (1, 1, embed_dim), mindspore.float32)) self.trans_dpr = [Tensor(x, mindspore.float32) for x in np.linspace(0, drop_path_rate, depth, dtype=np.float32)] # Classifier head self.trans_norm = nn.LayerNorm([embed_dim], epsilon=1e-05) self.trans_norm.layer_norm.shard(((dp, 1, 1), (1,), (1,))) self.trans_cls_head = nn.Dense(embed_dim, num_classes, weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16) self.trans_cls_head.matmul.shard(((dp, 1), (1, 1))) self.trans_cls_head.bias_add.shard(((dp, 1), (1,))) self.pooling = nn.AvgPool2d(kernel_size=7, stride=7) self.pooling.avg_pool.shard(((dp, 1, 1, 1),)) self.conv_cls_head = nn.Dense(int(256 * channel_ratio), num_classes, weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16) self.conv_cls_head.matmul.shard(((dp, 1), (1, 1))) self.conv_cls_head.bias_add.shard(((dp, 1), (1,))) # Stem stage: get the feature maps by conv block (copied form resnet.py) self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, has_bias=False, pad_mode="pad", weight_init=HeNormal(mode='fan_out', nonlinearity='relu')).to_float(mindspore.float16) self.conv1.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1))) self.conv1.bias_add.shard(((dp, 1, 1, 1), (1,))) self.bn1 = nn.BatchNorm2d(64) self.bn1.bn_train.shard(((dp, 1, 1, 1), (1,), (1,), (1,), (1,))) self.act1 = nn.ReLU() self.act1.relu.shard(((dp, 1, 1, 1),)) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") self.maxpool.max_pool.shard(((dp, 1, 1, 1),)) self.concat = P.Concat(axis=1).shard(((dp, 1, 1), (dp, 1, 1))) self.conv_trans_list = [] self.broadcastto = P.BroadcastTo((batch_size, -1, -1)).shard(((1, 1, 1),)) self.slice = P.StridedSlice().shard(((dp, 1, 1),)) self.squeeze = P.Squeeze(1).shard(((dp, 1, 1),)) self.mean = P.ReduceMean(keep_dims=True).shard(((dp, 1, 1),)) self.trunc_normal_ = Init.TruncatedNormal(.02) if self.cls_token_flag: self.trunc_normal_(self.cls_token.asnumpy()) self.init_stage1_4(base_channel, channel_ratio, patch_size, embed_dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, stage_point, weighted_fusion, seq_length, dp, mp) self.init_stage5_12(base_channel, channel_ratio, patch_size, embed_dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, stage_point, weighted_fusion, seq_length, depth, dp, mp) def init_stage1_4(self, base_channel, channel_ratio, patch_size, embed_dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, stage_point, weighted_fusion, seq_length, dp, mp): # 1 stage stage_1_channel = int(base_channel * channel_ratio) trans_dw_stride = patch_size // 4 self.conv_1 = ConvBlock(inplanes=64, outplanes=stage_1_channel, res_conv=True, stride=1, dp=dp) self.trans_patch_conv = nn.Conv2d(64, embed_dim, kernel_size=trans_dw_stride, stride=trans_dw_stride, padding=0, has_bias=True, pad_mode="pad", weight_init=HeNormal(mode='fan_out', nonlinearity='relu')).to_float(mindspore.float16) self.trans_patch_conv.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1))) self.trans_patch_conv.bias_add.shard(((dp, 1, 1, 1), (1,))) self.trans_1 = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=self.trans_dpr[0], dp=dp, mp=mp) # 2~4 stage init_stage = 2 fin_stage = stage_point[1] + 1 # fin_stage = depth // 3 + 1 for i in range(init_stage, fin_stage): self.conv_trans_list.append( ConvTransBlock(stage_1_channel, stage_1_channel, False, 1, dw_stride=trans_dw_stride, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=self.trans_dpr[i - 1], cls_token=self.cls_token_flag, weighted_fusion=weighted_fusion, dp=dp, mp=mp, seq_length=seq_length) ) def init_stage5_12(self, base_channel, channel_ratio, patch_size, embed_dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate, stage_point, weighted_fusion, seq_length, depth, dp, mp): stage_1_channel = int(base_channel * channel_ratio) stage_2_channel = int(base_channel * channel_ratio * 2) trans_dw_stride = patch_size // 4 fin_stage = stage_point[1] + 1 # 5~8 stage init_stage = fin_stage # 5 fin_stage = stage_point[2] + 1 # fin_stage = fin_stage + depth // 3 # 9 for i in range(init_stage, fin_stage): s = 2 if i == init_stage else 1 in_channel = stage_1_channel if i == init_stage else stage_2_channel res_conv = bool(i == init_stage) self.conv_trans_list.append( ConvTransBlock(in_channel, stage_2_channel, res_conv, s, dw_stride=trans_dw_stride // 2, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=self.trans_dpr[i - 1], cls_token=self.cls_token_flag, weighted_fusion=weighted_fusion, dp=dp, mp=mp, seq_length=seq_length) ) stage_3_channel = int(base_channel * channel_ratio * 2 * 2) # 9~12 stage init_stage = fin_stage # 9 fin_stage = stage_point[3] + 1 # fin_stage = fin_stage + depth // 3 # 13 for i in range(init_stage, fin_stage): s = 2 if i == init_stage else 1 in_channel = stage_2_channel if i == init_stage else stage_3_channel res_conv = bool(i == init_stage) last_fusion = bool(i == depth) self.conv_trans_list.append( ConvTransBlock( in_channel, stage_3_channel, res_conv, s, dw_stride=trans_dw_stride // 4, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=self.trans_dpr[i - 1], last_fusion=last_fusion, cls_token=self.cls_token_flag, weighted_fusion=weighted_fusion, dp=dp, mp=mp, seq_length=seq_length ) ) self.conv_trans_blks = nn.CellList(self.conv_trans_list) def construct(self, x): """conformer construct""" # x fp32 cls_tokens = None if self.cls_token_flag: cls_tokens = self.broadcastto(self.cls_token) # fp32 # stem stage [N, 3, 224, 224] -> [N, 64, 56, 56] x_fp32 = F.cast(self.conv1(F.cast(x, mindspore.float16)), mindspore.float32) x_fp16 = F.cast(self.bn1(x_fp32), mindspore.float16) x_base = self.maxpool(self.act1(x_fp16)) # fp16 # 1 stage x = self.conv_1(x_base) # fp16 tmp = self.trans_patch_conv(x_base) tmp1 = flatten(tmp, 2) x_t = F.cast(tmp1.transpose((0, 2, 1)), mindspore.float32) # fp32 if self.cls_token_flag: x_t = self.concat([cls_tokens, x_t]) x_t = self.trans_1(x_t) # fp32 # 2 ~ final for blk in self.conv_trans_blks: x, x_t = blk(x, x_t) # x fp16, x_t fp32 # conv classification tmp2 = self.pooling(x) x_p = flatten(tmp2, 1) conv_cls = self.conv_cls_head(x_p) # conv_cls fp16 # trans classification x_t = self.trans_norm(x_t) x_t = F.cast(x_t, mindspore.float16) b_size, _, height = F.shape(x_t) tmp3 = self.squeeze(self.slice(x_t, (0, 0, 0), (b_size, 1, height), (1, 1, 1))) if self.cls_token_flag: tran_cls = self.trans_cls_head(tmp3) else: tran_cls = self.trans_cls_head(self.mean(x_t, 1)) return [conv_cls, tran_cls] @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @pytest.mark.env_onecard def test_conformer_arm_ascend(): """ Feature: test conformer architecture Description: convolution and transformer Expectation: compile success """ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=1, global_rank=0) net = ConformerOverflow(patch_size=16, channel_ratio=4, embed_dim=384, stage_point=[1, 4, 8, 12], num_heads=6, mlp_ratio=4, qkv_bias=False, qk_scale=None, cls_token=True, num_classes=1000, drop_rate=0.0, drop_path_rate=0.1, attn_drop_rate=0.0, batch_size=32, weighted_fusion=True, dp=8, mp=1, seq_length=196) ls = CrossEntropySmooth(reduction="mean") net_with_loss_net = NetWithLossCell(net, ls) net_with_loss = _VirtualDatasetCell(net_with_loss_net) optimizer = nn.AdamWeightDecay(net.trainable_params()) train_net = nn.TrainOneStepCell(net_with_loss, optimizer) data = Tensor(np.ones([32, 3, 224, 224]), dtype=mindspore.float32) label = Tensor(np.ones([32]).astype(np.int32)) label = one_hot_int(label, 1000) train_net(data, label)