# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TNT"""
import math
import copy
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter

class MLP(nn.Cell):
    """MLP"""
    def __init__(self, in_features, hidden_features=None, out_features=None, dropout=0.):
        super(MLP, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Dense(in_features, hidden_features)
        self.dropout = nn.Dropout(1. - dropout)
        self.fc2 = nn.Dense(hidden_features, out_features)
        self.act = nn.GELU()

    def construct(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class Attention(nn.Cell):
    """Multi-head Attention"""
    def __init__(self, dim, hidden_dim=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super(Attention, self).__init__()
        hidden_dim = hidden_dim or dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        head_dim = hidden_dim // num_heads
        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        self.qk = nn.Dense(dim, hidden_dim * 2, has_bias=qkv_bias)
        self.v = nn.Dense(dim, hidden_dim, has_bias=qkv_bias)
        self.softmax = nn.Softmax(axis=-1)
        self.batmatmul_trans_b = P.BatchMatMul(transpose_b=True)
        self.attn_drop = nn.Dropout(1. - attn_drop)
        self.batmatmul = P.BatchMatMul()
        self.proj = nn.Dense(hidden_dim, dim)
        self.proj_drop = nn.Dropout(1. - proj_drop)
        self.transpose = P.Transpose()
        self.reshape = P.Reshape()

    def construct(self, x):
        """Multi-head Attention"""
        B, N, _ = x.shape
        # project to queries and keys, then split heads: (2, B, num_heads, N, head_dim)
        qk = self.transpose(self.reshape(self.qk(x), (B, N, 2, self.num_heads, self.head_dim)), (2, 0, 3, 1, 4))
        q, k = qk[0], qk[1]
        # values: (B, num_heads, N, head_dim)
        v = self.transpose(self.reshape(self.v(x), (B, N, self.num_heads, self.head_dim)), (0, 2, 1, 3))
        attn = self.softmax(self.batmatmul_trans_b(q, k) * self.scale)
        attn = self.attn_drop(attn)
        x = self.reshape(self.transpose(self.batmatmul(attn, v), (0, 2, 1, 3)), (B, N, -1))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class DropConnect(nn.Cell):
    """Drop connect implementation"""
    def __init__(self, drop_connect_rate=0., seed0=0, seed1=0):
        super(DropConnect, self).__init__()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.keep_prob = 1 - drop_connect_rate
        self.dropout = P.Dropout(keep_prob=self.keep_prob)
        self.keep_prob_tensor = Tensor(self.keep_prob, dtype=mstype.float32)

    def construct(self, x):
        shape = self.shape(x)
        dtype = self.dtype(x)
        # draw a per-sample keep mask of shape (batch, 1, 1, 1), then rescale the survivors
        ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1)
        _, mask = self.dropout(ones_tensor)
        x = x * mask
        x = x / self.keep_prob_tensor
        return x

class Pixel2Patch(nn.Cell):
    """Projecting Pixel Embedding to Patch Embedding"""
    def __init__(self, outer_dim):
        super(Pixel2Patch, self).__init__()
        self.norm_proj = nn.LayerNorm([outer_dim])
        self.proj = nn.Dense(outer_dim, outer_dim)
        self.fake = Parameter(Tensor(np.zeros((1, 1, outer_dim)), mstype.float32),
                              name='fake', requires_grad=False)
        self.reshape = P.Reshape()
        self.tile = P.Tile()
        self.concat = P.Concat(axis=1)

    def construct(self, pixel_embed, patch_embed):
        B, N, _ = patch_embed.shape
        # fold each patch's pixel embeddings back into a single vector per patch
        proj = self.reshape(pixel_embed, (B, N - 1, -1))
        proj = self.proj(self.norm_proj(proj))
        # prepend a zero vector for the class token so the shapes match patch_embed
        proj = self.concat((self.tile(self.fake, (B, 1, 1)), proj))
        patch_embed = patch_embed + proj
        return patch_embed

class TNTBlock(nn.Cell):
    """TNT Block"""
    def __init__(self, inner_config, outer_config, dropout=0., attn_dropout=0., drop_connect=0.):
        super().__init__()
        # inner transformer
        inner_dim = inner_config['dim']
        num_heads = inner_config['num_heads']
        mlp_ratio = inner_config['mlp_ratio']
        self.inner_norm1 = nn.LayerNorm([inner_dim])
        self.inner_attn = Attention(inner_dim, num_heads=num_heads, qkv_bias=True, attn_drop=attn_dropout,
                                    proj_drop=dropout)
        self.inner_norm2 = nn.LayerNorm([inner_dim])
        self.inner_mlp = MLP(inner_dim, int(inner_dim * mlp_ratio), dropout=dropout)
        # outer transformer
        outer_dim = outer_config['dim']
        num_heads = outer_config['num_heads']
        mlp_ratio = outer_config['mlp_ratio']
        self.outer_norm1 = nn.LayerNorm([outer_dim])
        self.outer_attn = Attention(outer_dim, num_heads=num_heads, qkv_bias=True, attn_drop=attn_dropout,
                                    proj_drop=dropout)
        self.outer_norm2 = nn.LayerNorm([outer_dim])
        self.outer_mlp = MLP(outer_dim, int(outer_dim * mlp_ratio), dropout=dropout)
        # pixel2patch
        self.pixel2patch = Pixel2Patch(outer_dim)
        # assistant
        self.drop_connect = DropConnect(drop_connect)
        self.reshape = P.Reshape()
        self.tile = P.Tile()
        self.concat = P.Concat(axis=1)

    def construct(self, pixel_embed, patch_embed):
        """TNT Block"""
        pixel_embed = pixel_embed + self.inner_attn(self.inner_norm1(pixel_embed))
        pixel_embed = pixel_embed + self.inner_mlp(self.inner_norm2(pixel_embed))
        patch_embed = self.pixel2patch(pixel_embed, patch_embed)
        patch_embed = patch_embed + self.outer_attn(self.outer_norm1(patch_embed))
        patch_embed = patch_embed + self.outer_mlp(self.outer_norm2(patch_embed))
        return pixel_embed, patch_embed

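# Note (illustrative, not from the original file): with the tnt_b configuration defined at
# the bottom of this file, pixel_embed stays (B * num_patches, 16, 40) and patch_embed stays
# (B, num_patches + 1, 640) through every block. Pixel2Patch relies on
# inner_dim * new_patch_size**2 equalling the outer embedding dim (40 * 16 = 640), since the
# folded pixel features feed a Dense(outer_dim, outer_dim) projection.
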
def _get_clones(module, N):
    """get_clones"""
    return nn.CellList([copy.deepcopy(module) for _ in range(N)])

class TNTEncoder(nn.Cell):
    """TNT encoder: a stack of TNT blocks"""
    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    def construct(self, pixel_embed, patch_embed):
        """Run pixel and patch embeddings through each block in turn"""
        for layer in self.layers:
            pixel_embed, patch_embed = layer(pixel_embed, patch_embed)
        return pixel_embed, patch_embed

class _stride_unfold_(nn.Cell):
    """Unfold with stride"""
    def __init__(self, kernel_size, stride=-1):
        super(_stride_unfold_, self).__init__()
        if stride == -1:
            self.stride = kernel_size
        else:
            self.stride = stride
        self.kernel_size = kernel_size
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.unfold = _unfold_(kernel_size)

    def construct(self, x):
        """Unfold with stride"""
        N, C, H, W = x.shape
        leftup_idx_x = []
        leftup_idx_y = []
        nh = int((H - self.kernel_size) / self.stride + 1)
        nw = int((W - self.kernel_size) / self.stride + 1)
        for i in range(nh):
            leftup_idx_x.append(i * self.stride)
        for i in range(nw):
            leftup_idx_y.append(i * self.stride)
        NumBlock_x = len(leftup_idx_x)
        NumBlock_y = len(leftup_idx_y)
        zeroslike = P.ZerosLike()
        cc_2 = P.Concat(axis=2)
        cc_3 = P.Concat(axis=3)
        # copy every (possibly overlapping) kernel_size x kernel_size block into a
        # non-overlapping grid, then unfold that grid with the plain _unfold_ below
        unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size,
                           NumBlock_y * self.kernel_size), mstype.float32)
        N, C, H, W = unf_x.shape
        for i in range(NumBlock_x):
            for j in range(NumBlock_y):
                unf_i = i * self.kernel_size
                unf_j = j * self.kernel_size
                org_i = leftup_idx_x[i]
                org_j = leftup_idx_y[j]
                fill = x[:, :, org_i:org_i + self.kernel_size,
                         org_j:org_j + self.kernel_size]
                unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]),
                                     cc_2((cc_2((zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]),
                                                 fill)),
                                           zeroslike(unf_x[:, :, unf_i + self.kernel_size:,
                                                           unf_j:unf_j + self.kernel_size]))))),
                               zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:])))
        y = self.unfold(unf_x)
        return y

class _unfold_(nn.Cell):
    """Unfold"""
    def __init__(self, kernel_size, stride=-1):
        super(_unfold_, self).__init__()
        if stride == -1:
            self.stride = kernel_size
        else:
            self.stride = stride
        self.kernel_size = kernel_size
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()

    def construct(self, x):
        """Split (N, C, H, W) into non-overlapping kernel_size blocks, flattened per block"""
        N, C, H, W = x.shape
        numH = int(H / self.kernel_size)
        numW = int(W / self.kernel_size)
        # crop away any remainder that does not fill a whole block
        if numH * self.kernel_size != H or numW * self.kernel_size != W:
            x = x[:, :, :numH * self.kernel_size, :numW * self.kernel_size]
            W = numW * self.kernel_size
        output_img = self.reshape(x, (N, C, numH, self.kernel_size, W))
        output_img = self.transpose(output_img, (0, 1, 2, 4, 3))
        output_img = self.reshape(output_img, (N, C, int(numH * numW), self.kernel_size, self.kernel_size))
        output_img = self.transpose(output_img, (0, 2, 1, 4, 3))
        output_img = self.reshape(output_img, (N, int(numH * numW), -1))
        return output_img

class PixelEmbed(nn.Cell):
    """Image to Pixel Embedding"""
    def __init__(self, img_size, patch_size=16, in_channels=3, embedding_dim=768, stride=4):
        super(PixelEmbed, self).__init__()
        self.num_patches = (img_size // patch_size) * (img_size // patch_size)
        new_patch_size = math.ceil(patch_size / stride)
        self.new_patch_size = new_patch_size
        self.inner_dim = embedding_dim // new_patch_size // new_patch_size
        self.proj = nn.Conv2d(in_channels, self.inner_dim, kernel_size=7, pad_mode='pad',
                              padding=3, stride=stride, has_bias=True)
        self.unfold = _unfold_(kernel_size=new_patch_size)
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()

    def construct(self, x):
        B = x.shape[0]
        x = self.proj(x)  # B, C, H, W
        x = self.unfold(x)  # B, N, Ck2
        x = self.reshape(x, (B * self.num_patches, self.inner_dim, -1))  # B*N, C, M
        x = self.transpose(x, (0, 2, 1))  # B*N, M, C
        return x

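# Shape walk-through (illustrative, assuming the tnt_b configuration: img_size=384,
# patch_size=16, stride=4, embedding_dim=640): the 7x7 stride-4 convolution maps a
# (B, 3, 384, 384) input to (B, 40, 96, 96); _unfold_ with kernel_size=4 cuts that into
# 24 * 24 = 576 blocks of 40 * 4 * 4 = 640 values, and the final reshape/transpose yields
# pixel embeddings of shape (B * 576, 16, 40).
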
class TNT(nn.Cell):
    """TNT"""
    def __init__(
            self,
            img_size,
            patch_size,
            num_channels,
            embedding_dim,
            num_heads,
            num_layers,
            hidden_dim,
            num_class,
            stride=4,
            dropout=0,
            attn_dropout=0,
            drop_connect=0.1
    ):
        super(TNT, self).__init__()
        assert embedding_dim % num_heads == 0
        assert img_size % patch_size == 0
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.img_size = img_size
        self.num_patches = int((img_size // patch_size) ** 2)
        new_patch_size = math.ceil(patch_size / stride)
        inner_dim = embedding_dim // new_patch_size // new_patch_size
        self.patch_pos = Parameter(Tensor(np.random.rand(1, self.num_patches + 1, embedding_dim),
                                          mstype.float32), name='patch_pos', requires_grad=True)
        self.pixel_pos = Parameter(Tensor(np.random.rand(1, inner_dim, new_patch_size * new_patch_size),
                                          mstype.float32), name='pixel_pos', requires_grad=True)
        self.cls_token = Parameter(Tensor(np.random.rand(1, 1, embedding_dim),
                                          mstype.float32), requires_grad=True)
        self.patch_embed = Parameter(Tensor(np.zeros((1, self.num_patches, embedding_dim)),
                                            mstype.float32), name='patch_embed', requires_grad=False)
        self.fake = Parameter(Tensor(np.zeros((1, 1, embedding_dim)),
                                     mstype.float32), name='fake', requires_grad=False)
        self.pos_drop = nn.Dropout(1. - dropout)
        self.pixel_embed = PixelEmbed(img_size, patch_size, num_channels, embedding_dim, stride)
        self.pixel2patch = Pixel2Patch(embedding_dim)
        inner_config = {'dim': inner_dim, 'num_heads': 4, 'mlp_ratio': 4}
        outer_config = {'dim': embedding_dim, 'num_heads': num_heads, 'mlp_ratio': hidden_dim / embedding_dim}
        encoder_layer = TNTBlock(inner_config, outer_config, dropout=dropout, attn_dropout=attn_dropout,
                                 drop_connect=drop_connect)
        self.encoder = TNTEncoder(encoder_layer, num_layers)
        self.head = nn.SequentialCell(
            nn.LayerNorm([embedding_dim]),
            nn.Dense(embedding_dim, num_class)
        )
        self.add = P.TensorAdd()
        self.reshape = P.Reshape()
        self.concat = P.Concat(axis=1)
        self.tile = P.Tile()
        self.transpose = P.Transpose()

    def construct(self, x):
        """TNT"""
        B, _, _, _ = x.shape
        pixel_embed = self.pixel_embed(x)
        pixel_embed = pixel_embed + self.transpose(self.pixel_pos, (0, 2, 1))  # B*N, M, C
        patch_embed = self.concat((self.cls_token, self.patch_embed))
        patch_embed = self.tile(patch_embed, (B, 1, 1))
        patch_embed = self.pos_drop(patch_embed + self.patch_pos)
        patch_embed = self.pixel2patch(pixel_embed, patch_embed)
        pixel_embed, patch_embed = self.encoder(pixel_embed, patch_embed)
        y = self.head(patch_embed[:, 0])
        return y

def tnt_b(num_class):
    """TNT-B: 640-dim, 10-head, 12-layer configuration for 384x384 inputs"""
    return TNT(img_size=384,
               patch_size=16,
               num_channels=3,
               embedding_dim=640,
               num_heads=10,
               num_layers=12,
               hidden_dim=640*4,
               stride=4,
               num_class=num_class)
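
# Minimal usage sketch (illustrative, not part of the original file). It assumes a
# MindSpore 1.x environment matching the keep_prob-style nn.Dropout calls above and a
# dummy ImageNet-sized batch; set the execution context as needed before running.
if __name__ == '__main__':
    net = tnt_b(num_class=1000)
    dummy = Tensor(np.random.rand(2, 3, 384, 384), mstype.float32)
    logits = net(dummy)
    print(logits.shape)  # expected: (2, 1000)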