|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """TNT"""
- import math
- import copy
- import numpy as np
- import mindspore.common.dtype as mstype
- from mindspore import nn
- from mindspore.ops import operations as P
- from mindspore.common.tensor import Tensor
- from mindspore.common.parameter import Parameter
-
-
class MLP(nn.Cell):
    """Feed-forward block: Dense -> GELU -> Dropout -> Dense -> Dropout.

    Args:
        in_features (int): width of the last input dimension.
        hidden_features (int, optional): width of the hidden layer; falls back to
            ``in_features`` when not given (or falsy).
        out_features (int, optional): width of the last output dimension; falls back
            to ``in_features`` when not given (or falsy).
        dropout (float): drop probability. The same ``nn.Dropout`` cell (built with
            keep probability ``1 - dropout``) is applied after the activation and
            after the second Dense layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, dropout=0.):
        super(MLP, self).__init__()
        out_dim = out_features if out_features else in_features
        hidden_dim = hidden_features if hidden_features else in_features
        keep_prob = 1. - dropout
        # Attribute names are kept as-is: they determine parameter/checkpoint names.
        self.fc1 = nn.Dense(in_features, hidden_dim)
        self.dropout = nn.Dropout(keep_prob)
        self.fc2 = nn.Dense(hidden_dim, out_dim)
        self.act = nn.GELU()

    def construct(self, x):
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
-
-
class Attention(nn.Cell):
    """Multi-head self-attention.

    Args:
        dim (int): width of the last input (and output) dimension.
        hidden_dim (int, optional): total width of the query/key/value projections;
            falls back to ``dim`` when not given (or falsy). Split evenly over heads.
        num_heads (int): number of attention heads.
        qkv_bias (bool): whether the query/key and value projections have a bias.
        attn_drop (float): drop probability applied to the attention weights.
        proj_drop (float): drop probability applied after the output projection.
    """

    def __init__(self, dim, hidden_dim=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super(Attention, self).__init__()
        if not hidden_dim:
            hidden_dim = dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Query and key share one projection; value has its own.
        self.qk = nn.Dense(dim, hidden_dim * 2, has_bias=qkv_bias)
        self.v = nn.Dense(dim, hidden_dim, has_bias=qkv_bias)
        self.softmax = nn.Softmax(axis=-1)
        self.batmatmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.attn_drop = nn.Dropout(1. - attn_drop)
        self.batmatmul = P.BatchMatMul()
        self.proj = nn.Dense(hidden_dim, dim)
        self.proj_drop = nn.Dropout(1. - proj_drop)

        self.transpose = P.Transpose()
        self.reshape = P.Reshape()

    def construct(self, x):
        """Multi-head Attention on a 3-D input ``x`` of shape (B, N, dim); returns (B, N, dim)."""
        batch, seq_len, _ = x.shape

        # (B, N, 2*hidden) -> (B, N, 2, heads, head_dim) -> (2, B, heads, N, head_dim)
        qk = self.reshape(self.qk(x), (batch, seq_len, 2, self.num_heads, self.head_dim))
        qk = self.transpose(qk, (2, 0, 3, 1, 4))
        query = qk[0]
        key = qk[1]

        # (B, N, hidden) -> (B, N, heads, head_dim) -> (B, heads, N, head_dim)
        value = self.reshape(self.v(x), (batch, seq_len, self.num_heads, self.head_dim))
        value = self.transpose(value, (0, 2, 1, 3))

        scores = self.batmatmul_trans_b(query, key) * self.scale
        weights = self.attn_drop(self.softmax(scores))

        # (B, heads, N, head_dim) -> (B, N, heads, head_dim) -> (B, N, hidden)
        context = self.transpose(self.batmatmul(weights, value), (0, 2, 1, 3))
        context = self.reshape(context, (batch, seq_len, -1))
        return self.proj_drop(self.proj(context))
-
-
class DropConnect(nn.Cell):
    """Per-sample drop connect along the leading (batch) axis.

    In training mode, each sample ``x[b]`` is either zeroed (probability
    ``drop_connect_rate``) or kept and rescaled by ``1 / keep_prob``. The mask has
    shape ``(B, 1, ..., 1)`` matching the rank of ``x``, so it works for 3-D token
    tensors as well as 4-D feature maps. The input is returned unchanged in
    inference mode (``set_train(False)``) and when ``drop_connect_rate`` is 0,
    mirroring ``nn.Dropout``.

    Args:
        drop_connect_rate (float): probability of dropping a sample.
        seed0 (int): accepted for backward compatibility; currently unused.
        seed1 (int): accepted for backward compatibility; currently unused.
    """

    def __init__(self, drop_connect_rate=0., seed0=0, seed1=0):
        super(DropConnect, self).__init__()
        self.shape = P.Shape()
        self.dtype = P.DType()
        # Created once here instead of on every construct call.
        self.fill = P.Fill()
        self.keep_prob = 1 - drop_connect_rate
        self.dropout = P.Dropout(keep_prob=self.keep_prob)
        self.keep_prob_tensor = Tensor(self.keep_prob, dtype=mstype.float32)

    def construct(self, x):
        # Drop connect is a training-time regularizer only; with keep_prob == 1 the
        # mask would be all ones, so skipping is equivalent and cheaper.
        if not self.training or self.keep_prob == 1:
            return x
        shape = self.shape(x)
        # Build (B, 1, ..., 1) with the same rank as x so the mask broadcasts per
        # sample; the former hard-coded (B, 1, 1, 1) broadcast 3-D inputs to 4-D.
        mask_shape = (shape[0],)
        for _ in range(len(shape) - 1):
            mask_shape = mask_shape + (1,)
        ones_tensor = self.fill(self.dtype(x), mask_shape, 1)
        _, mask = self.dropout(ones_tensor)
        x = x * mask
        x = x / self.keep_prob_tensor
        return x
-
-
class Pixel2Patch(nn.Cell):
    """Projecting Pixel Embedding to Patch Embedding.

    Flattens the pixel embeddings of each patch, applies LayerNorm + Dense, and adds
    the result to the patch embeddings. Position 0 of the patch sequence receives a
    zero row (the non-trainable ``fake`` parameter), so only positions 1..N-1 are updated.

    Args:
        outer_dim (int): patch embedding width. Assumes the flattened pixel features
            of one patch also have width ``outer_dim`` (required by the LayerNorm/Dense).
    """

    def __init__(self, outer_dim):
        super(Pixel2Patch, self).__init__()
        self.norm_proj = nn.LayerNorm([outer_dim])
        self.proj = nn.Dense(outer_dim, outer_dim)
        zero_row = Tensor(np.zeros((1, 1, outer_dim)), mstype.float32)
        self.fake = Parameter(zero_row, name='fake', requires_grad=False)
        self.reshape = P.Reshape()
        self.tile = P.Tile()
        self.concat = P.Concat(axis=1)

    def construct(self, pixel_embed, patch_embed):
        batch, num_tokens, _ = patch_embed.shape
        # One flattened row per patch; the leading token (position 0) has no pixels.
        per_patch = self.reshape(pixel_embed, (batch, num_tokens - 1, -1))
        projected = self.proj(self.norm_proj(per_patch))
        # Prepend a zero row so the projection lines up with all N patch tokens.
        leading_zeros = self.tile(self.fake, (batch, 1, 1))
        projected = self.concat((leading_zeros, projected))
        return patch_embed + projected
-
-
class TNTBlock(nn.Cell):
    """TNT Block: inner transformer on pixel tokens, pixel->patch fusion, outer transformer on patch tokens.

    Args:
        inner_config (dict): keys 'dim', 'num_heads', 'mlp_ratio' for the inner (pixel) transformer.
        outer_config (dict): the same keys for the outer (patch) transformer.
        dropout (float): drop probability for attention output projections and MLPs.
        attn_dropout (float): drop probability on attention weights.
        drop_connect (float): rate handed to ``self.drop_connect``.
    """

    def __init__(self, inner_config, outer_config, dropout=0., attn_dropout=0., drop_connect=0.):
        super().__init__()
        # inner transformer (pixel level)
        inner_dim = inner_config['dim']
        inner_heads = inner_config['num_heads']
        inner_hidden = int(inner_dim * inner_config['mlp_ratio'])
        self.inner_norm1 = nn.LayerNorm([inner_dim])
        self.inner_attn = Attention(inner_dim, num_heads=inner_heads, qkv_bias=True,
                                    attn_drop=attn_dropout, proj_drop=dropout)
        self.inner_norm2 = nn.LayerNorm([inner_dim])
        self.inner_mlp = MLP(inner_dim, inner_hidden, dropout=dropout)

        # outer transformer (patch level)
        outer_dim = outer_config['dim']
        outer_heads = outer_config['num_heads']
        outer_hidden = int(outer_dim * outer_config['mlp_ratio'])
        self.outer_norm1 = nn.LayerNorm([outer_dim])
        self.outer_attn = Attention(outer_dim, num_heads=outer_heads, qkv_bias=True,
                                    attn_drop=attn_dropout, proj_drop=dropout)
        self.outer_norm2 = nn.LayerNorm([outer_dim])
        self.outer_mlp = MLP(outer_dim, outer_hidden, dropout=dropout)

        # fusion of pixel features into patch features
        self.pixel2patch = Pixel2Patch(outer_dim)

        # NOTE(review): drop_connect, reshape, tile and concat are created but never used
        # in construct, so the ``drop_connect`` argument currently has no effect — confirm intent.
        self.drop_connect = DropConnect(drop_connect)
        self.reshape = P.Reshape()
        self.tile = P.Tile()
        self.concat = P.Concat(axis=1)

    def construct(self, pixel_embed, patch_embed):
        """Return the updated ``(pixel_embed, patch_embed)`` pair (pre-norm residual updates)."""
        # inner transformer on pixel tokens
        attn_out = self.inner_attn(self.inner_norm1(pixel_embed))
        pixel_embed = pixel_embed + attn_out
        mlp_out = self.inner_mlp(self.inner_norm2(pixel_embed))
        pixel_embed = pixel_embed + mlp_out

        # inject pixel information into the patch tokens
        patch_embed = self.pixel2patch(pixel_embed, patch_embed)

        # outer transformer on patch tokens
        attn_out = self.outer_attn(self.outer_norm1(patch_embed))
        patch_embed = patch_embed + attn_out
        mlp_out = self.outer_mlp(self.outer_norm2(patch_embed))
        patch_embed = patch_embed + mlp_out
        return pixel_embed, patch_embed
-
-
def _get_clones(module, N):
    """Return an ``nn.CellList`` holding ``N`` independent deep copies of ``module``."""
    clones = []
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return nn.CellList(clones)
-
-
class TNTEncoder(nn.Cell):
    """Stack of ``num_layers`` deep copies of ``encoder_layer``, applied in order.

    Args:
        encoder_layer (nn.Cell): template block; each layer is an independent deep copy.
        num_layers (int): number of stacked copies.
    """

    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    def construct(self, pixel_embed, patch_embed):
        """Feed the (pixel, patch) pair through every layer and return the final pair."""
        for block in self.layers:
            pixel_embed, patch_embed = block(pixel_embed, patch_embed)
        return pixel_embed, patch_embed
-
-
class _stride_unfold_(nn.Cell):
    """Unfold with stride.

    Extracts ``kernel_size`` x ``kernel_size`` windows from a 4-D input (N, C, H, W)
    at the given stride. Each window is copied into its own tile of a zero canvas of
    shape (N, C, NumBlock_x * kernel_size, NumBlock_y * kernel_size), so the windows
    no longer overlap. The canvas is then passed to ``_unfold_`` (non-overlapping unfold).

    Args:
        kernel_size (int): side length of each square window.
        stride (int): step between window origins; -1 means ``kernel_size``.
    """

    def __init__(
            self, kernel_size, stride=-1):
        super(_stride_unfold_, self).__init__()
        if stride == -1:
            self.stride = kernel_size
        else:
            self.stride = stride
        self.kernel_size = kernel_size
        # NOTE(review): self.reshape and self.transpose are not used in construct.
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.unfold = _unfold_(kernel_size)

    def construct(self, x):
        """Return ``self.unfold`` applied to the canvas of tiled strided windows of ``x``."""
        N, C, H, W = x.shape
        leftup_idx_x = []
        leftup_idx_y = []
        # number of window positions along H and W
        nh = int((H - self.kernel_size) / self.stride + 1)
        nw = int((W - self.kernel_size) / self.stride + 1)
        # top-left row / column of every window in the original input
        for i in range(nh):
            leftup_idx_x.append(i * self.stride)
        for i in range(nw):
            leftup_idx_y.append(i * self.stride)
        NumBlock_x = len(leftup_idx_x)
        NumBlock_y = len(leftup_idx_y)
        zeroslike = P.ZerosLike()
        cc_2 = P.Concat(axis=2)
        cc_3 = P.Concat(axis=3)
        # NOTE(review): the canvas dtype is fixed to float32 regardless of x's dtype — confirm
        # inputs are always float32.
        unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size,
                           NumBlock_y * self.kernel_size), mstype.float32)
        # rebinds N, C, H, W to the canvas shape (H and W are not read again below)
        N, C, H, W = unf_x.shape
        for i in range(NumBlock_x):
            for j in range(NumBlock_y):
                # destination tile origin in the canvas, source window origin in x
                unf_i = i * self.kernel_size
                unf_j = j * self.kernel_size
                org_i = leftup_idx_x[i]
                org_j = leftup_idx_y[j]
                fill = x[:, :, org_i:org_i + self.kernel_size,
                         org_j:org_j + self.kernel_size]
                # Assemble a canvas-sized tensor that is zero everywhere except tile (i, j),
                # which holds `fill`: zero rows above/below `fill` are concatenated on axis 2
                # to form the tile's column strip, then zero columns left/right of that strip
                # are concatenated on axis 3. The result is accumulated into unf_x.
                unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]),
                                     cc_2((cc_2((zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fill)),
                                           zeroslike(unf_x[:, :, unf_i + self.kernel_size:,
                                                           unf_j:unf_j + self.kernel_size]))))),
                               zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:])))
        y = self.unfold(unf_x)
        return y
-
-
class _unfold_(nn.Cell):
    """Unfold a 4-D tensor into non-overlapping square patches.

    Input ``x`` has shape (N, C, H, W). The output has shape
    (N, numH * numW, C * kernel_size * kernel_size), where ``numH = H // kernel_size``
    and ``numW = W // kernel_size``. Patches are ordered row-major over the patch grid,
    and each patch is flattened in (C, row, col) order. Trailing rows/columns that do
    not fill a whole patch are cropped away.

    Args:
        kernel_size (int): patch side length.
        stride (int): stored as ``self.stride`` (``kernel_size`` when -1). ``construct``
            always uses non-overlapping patches, i.e. an effective stride of ``kernel_size``.
    """

    def __init__(
            self, kernel_size, stride=-1):
        super(_unfold_, self).__init__()
        # Set self.stride in both cases (previously it was left undefined when stride != -1).
        self.stride = kernel_size if stride == -1 else stride
        self.kernel_size = kernel_size

        self.reshape = P.Reshape()
        self.transpose = P.Transpose()

    def construct(self, x):
        """Return the patches of ``x`` as (N, numH * numW, C * kernel_size**2)."""
        N, C, H, W = x.shape
        k = self.kernel_size
        numH = int(H / k)
        numW = int(W / k)
        crop_h = numH * k
        crop_w = numW * k
        if crop_h != H or crop_w != W:
            # Drop trailing rows/columns that do not fill a whole patch. (The former slice
            # used 5 indices on a 4-D tensor and the reshape below used the uncropped W.)
            x = x[:, :, :crop_h, :crop_w]
        # (N, C, numH, k, crop_w): split rows into patch rows
        output_img = self.reshape(x, (N, C, numH, k, crop_w))
        # (N, C, numH, crop_w, k)
        output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
        # (N, C, numH * numW, k [col], k [row]): split columns into patch columns
        output_img = self.reshape(output_img, (N, C, numH * numW, k, k))
        # (N, numH * numW, C, k [row], k [col])
        output_img = self.transpose(output_img, (0, 2, 1, 4, 3))

        output_img = self.reshape(output_img, (N, numH * numW, -1))
        return output_img
-
-
class PixelEmbed(nn.Cell):
    """Image to Pixel Embedding.

    A 7x7 convolution (padding 3, given stride) maps the image to ``inner_dim``
    channels. The feature map is split by ``_unfold_`` into non-overlapping
    ``new_patch_size`` x ``new_patch_size`` patches. Each patch becomes a sequence of
    pixel tokens of width ``inner_dim``.

    Args:
        img_size (int): input image side length.
        patch_size (int): outer patch side length in input pixels.
        in_channels (int): image channels.
        embedding_dim (int): outer embedding width; ``inner_dim`` is
            ``embedding_dim // new_patch_size // new_patch_size``.
        stride (int): convolution stride; ``new_patch_size = ceil(patch_size / stride)``.
    """

    def __init__(self, img_size, patch_size=16, in_channels=3, embedding_dim=768, stride=4):
        super(PixelEmbed, self).__init__()
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.new_patch_size = math.ceil(patch_size / stride)
        self.inner_dim = embedding_dim // self.new_patch_size // self.new_patch_size
        self.proj = nn.Conv2d(in_channels, self.inner_dim, kernel_size=7, pad_mode='pad',
                              padding=3, stride=stride, has_bias=True)
        self.unfold = _unfold_(kernel_size=self.new_patch_size)
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()

    def construct(self, x):
        batch = x.shape[0]
        features = self.proj(x)            # B, C, H, W
        patches = self.unfold(features)    # B, N, C*k*k
        # B*N, C, M  (M pixels per patch)
        patches = self.reshape(patches, (batch * self.num_patches, self.inner_dim, -1))
        return self.transpose(patches, (0, 2, 1))  # B*N, M, C
-
-
class TNT(nn.Cell):
    """TNT (Transformer in Transformer) image classifier.

    Args:
        img_size (int): input image side length; must be divisible by ``patch_size``.
        patch_size (int): outer patch side length.
        num_channels (int): input image channels.
        embedding_dim (int): outer (patch) embedding width; must be divisible by ``num_heads``.
        num_heads (int): heads of the outer attention (the inner attention uses 4).
        num_layers (int): number of TNT blocks.
        hidden_dim (int): outer MLP hidden width, forwarded as ``hidden_dim / embedding_dim``.
        num_class (int): number of output logits.
        stride (int): stride of the pixel-embedding convolution.
        dropout (float): drop probability for the positional dropout and the blocks'
            attention projections and MLPs.
        attn_dropout (float): drop probability on attention weights.
        drop_connect (float): rate forwarded to every TNT block.
    """

    def __init__(
            self,
            img_size,
            patch_size,
            num_channels,
            embedding_dim,
            num_heads,
            num_layers,
            hidden_dim,
            num_class,
            stride=4,
            dropout=0,
            attn_dropout=0,
            drop_connect=0.1
    ):
        super(TNT, self).__init__()

        assert embedding_dim % num_heads == 0
        assert img_size % patch_size == 0
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.img_size = img_size
        self.num_patches = int((img_size // patch_size) ** 2)
        pixels_per_side = math.ceil(patch_size / stride)
        pixels_per_patch = pixels_per_side * pixels_per_side
        inner_dim = embedding_dim // pixels_per_side // pixels_per_side

        # Random draws stay in this order: patch_pos, pixel_pos, cls_token.
        patch_pos_init = np.random.rand(1, self.num_patches + 1, embedding_dim)
        self.patch_pos = Parameter(Tensor(patch_pos_init, mstype.float32),
                                   name='patch_pos', requires_grad=True)
        pixel_pos_init = np.random.rand(1, inner_dim, pixels_per_patch)
        self.pixel_pos = Parameter(Tensor(pixel_pos_init, mstype.float32),
                                   name='pixel_pos', requires_grad=True)
        cls_init = np.random.rand(1, 1, embedding_dim)
        self.cls_token = Parameter(Tensor(cls_init, mstype.float32), requires_grad=True)
        # Constant zero patch tokens; pixel information is added later by pixel2patch.
        self.patch_embed = Parameter(Tensor(np.zeros((1, self.num_patches, embedding_dim)), mstype.float32),
                                     name='patch_embed', requires_grad=False)
        # NOTE(review): self.fake is not used in construct; kept for parameter/checkpoint compatibility.
        self.fake = Parameter(Tensor(np.zeros((1, 1, embedding_dim)), mstype.float32),
                              name='fake', requires_grad=False)
        self.pos_drop = nn.Dropout(1. - dropout)

        self.pixel_embed = PixelEmbed(img_size, patch_size, num_channels, embedding_dim, stride)
        self.pixel2patch = Pixel2Patch(embedding_dim)

        inner_config = {'dim': inner_dim, 'num_heads': 4, 'mlp_ratio': 4}
        outer_config = {'dim': embedding_dim, 'num_heads': num_heads, 'mlp_ratio': hidden_dim / embedding_dim}
        block = TNTBlock(inner_config, outer_config, dropout=dropout, attn_dropout=attn_dropout,
                         drop_connect=drop_connect)
        self.encoder = TNTEncoder(block, num_layers)

        self.head = nn.SequentialCell(
            nn.LayerNorm([embedding_dim]),
            nn.Dense(embedding_dim, num_class)
        )

        self.add = P.TensorAdd()
        self.reshape = P.Reshape()
        self.concat = P.Concat(axis=1)
        self.tile = P.Tile()
        self.transpose = P.Transpose()

    def construct(self, x):
        """Return class logits (B, num_class) for a 4-D image batch ``x``."""
        batch, _, _, _ = x.shape
        pixel_tokens = self.pixel_embed(x)  # B*N, M, C
        pixel_tokens = pixel_tokens + self.transpose(self.pixel_pos, (0, 2, 1))

        # [cls | N zero patch tokens], broadcast over the batch, plus positions
        patch_tokens = self.concat((self.cls_token, self.patch_embed))
        patch_tokens = self.tile(patch_tokens, (batch, 1, 1))
        patch_tokens = self.pos_drop(patch_tokens + self.patch_pos)

        patch_tokens = self.pixel2patch(pixel_tokens, patch_tokens)

        _, patch_tokens = self.encoder(pixel_tokens, patch_tokens)

        cls_feature = patch_tokens[:, 0]
        return self.head(cls_feature)
-
-
def tnt_b(num_class):
    """Build the TNT-B configuration (384px input, 16px patches, 640-d, 10 heads, 12 layers).

    Args:
        num_class (int): number of output logits.

    Returns:
        TNT: the constructed network.
    """
    embedding_dim = 640
    settings = dict(
        img_size=384,
        patch_size=16,
        num_channels=3,
        embedding_dim=embedding_dim,
        num_heads=10,
        num_layers=12,
        hidden_dim=embedding_dim * 4,
        stride=4,
    )
    return TNT(num_class=num_class, **settings)
|