You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

test_conformer.py 38 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828
  1. # Copyright 2022 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """UT test example for conformer."""
  15. import math
  16. import numpy as np
  17. import pytest
  18. import mindspore
  19. import mindspore.nn as nn
  20. import mindspore.common.initializer as Init
  21. from mindspore import Tensor, context, Parameter
  22. from mindspore.ops import operations as P
  23. from mindspore.common.initializer import initializer
  24. from mindspore.ops import functional as F
  25. from mindspore.common.initializer import TruncatedNormal, HeNormal
  26. from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell
  27. from mindspore.nn.loss.loss import LossBase
  28. mindspore.set_seed(0)
  29. np.random.seed(0)
  30. def flatten(input_tensor, start_dim):
  31. shape = input_tensor.shape
  32. new_shape = shape[:start_dim]
  33. dims = 1
  34. for i in range(start_dim, len(shape)):
  35. dims = dims * shape[i]
  36. return input_tensor.reshape(new_shape+(dims,))
  37. def one_hot_int(label, num_classes):
  38. num_elements = label.size
  39. one_hot_label = np.zeros((num_elements, num_classes), dtype=np.int32)
  40. for index in range(num_elements):
  41. one_hot_label[index][label[index]] = 1
  42. return Tensor(one_hot_label, mindspore.float32)
class CrossEntropySmooth(LossBase):
    """Softmax cross-entropy averaged over a sequence of logits.

    ``logit`` is iterable (e.g. the two conformer heads); each element is
    cast to float32 and contributes ``1/len(logit)`` of the total loss.
    """
    def __init__(self, reduction='mean', is_auto_parallel=False):
        super(CrossEntropySmooth, self).__init__()
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)
        if is_auto_parallel:
            # Average the reduced loss across the parallel group as well.
            self.ce.reduce_mean.add_prim_attr("cross_batch", True)

    def construct(self, logit, label):
        loss = None
        idx = 0
        for o in logit:
            o = F.cast(o, mindspore.float32)
            # First iteration initialises the accumulator; later ones add to it.
            loss = self.ce(o, label) / len(logit) if idx == 0 else loss + self.ce(o, label) / len(logit)
            idx = idx + 1
        return loss
class NetWithLossCell(nn.Cell):
    """Wrap a backbone network and a loss function into one trainable cell."""
    def __init__(self, backbone, loss_fn):
        super(NetWithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, data, label):
        output = self._backbone(data)
        loss = self._loss_fn(output, label)
        return loss
class DropPath(nn.Cell):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None, num_dimension=4, dp=1):
        super(DropPath, self).__init__()
        # Probability of zeroing a whole sample's path; None/0 disables it.
        self.drop_prob = drop_prob
        # Shard only the batch (first) axis across `dp` devices.
        strategy_feat = (dp,) + (1,)*(num_dimension-1)
        self.uniformreal = P.UniformReal().shard((strategy_feat,))
        self.floor = P.Floor().shard((strategy_feat,))
        self.div = P.Div().shard((strategy_feat, ()))
        self.mul = P.Mul().shard((strategy_feat, strategy_feat))
        self.add = P.Add().shard(((), strategy_feat))

    def drop_path(self, x, drop_prob=0., training=True):
        """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
        This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
        the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
        See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
        changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
        'survival rate' as the argument.
        """
        if drop_prob == 0. or not training:
            return x
        keep_prob = 1 - drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
        # keep_prob + U[0,1): floor() gives a per-sample 0/1 Bernoulli mask.
        random_tensor = self.add(keep_prob, F.cast(self.uniformreal(shape), mindspore.float32))
        random_tensor = self.floor(random_tensor)
        # Divide by keep_prob so the expectation of the output matches the input.
        output = self.mul(self.div(x, keep_prob), random_tensor)
        return output  # fp32

    def construct(self, x):
        return self.drop_path(x, self.drop_prob, self.training)
class Norm(nn.Cell):
    r"""
    A self-defined layer norm operation using reduce sum and reduce mean

    Args:
        normalized_shape (tuple): The shape of the input tensor
        axes (int): Axis (or axes) reduced over when computing mean/variance. Default: -1.
        num_dimension (int): Rank of the expected input. Default: 3.
        affine (bool): If True, gamma/beta are learnable Parameters. Default: True.
        dp (int): The data parallel way of the inputs, Default:1
        eps (float): The epsilon value of the denominator. Default 1e-5.
        is_gn (bool): If True, behave as group norm over `num_groups` groups. Default: False.
        num_groups (int): Number of groups when `is_gn` is True. Default: 1.

    Inputs:
        - **x** (Tensor) - Tensor of shape :math:`(batch, seq\_length, hidden\_size)`.

    Outputs:
        Tensor of shape :math:`(batch, seq_length, hidden_size)`.
    """
    def __init__(self, normalized_shape, axes=-1,
                 num_dimension=3, affine=True,
                 dp=1, eps=1e-5, is_gn=False, num_groups=1):
        super(Norm, self).__init__()
        gamma = initializer('ones', normalized_shape)
        beta = initializer('zeros', normalized_shape)
        if affine:
            # Keep a full copy on each device (no optimizer sharding).
            self.gamma = Parameter(gamma, name="gamma", parallel_optimizer=False)
            self.beta = Parameter(beta, name="beta", parallel_optimizer=False)
        else:
            self.gamma = gamma
            self.beta = beta
        # Shard only the batch axis across `dp` devices.
        strategy = [dp if i == 0 else 1 for i in range(num_dimension)]
        strategy = tuple(strategy)
        if is_gn:
            # Group-norm path reshapes to (batch, num_groups, -1): one dim fewer.
            strategy1 = [dp if i == 0 else 1 for i in range(num_dimension-1)]
            strategy1 = tuple(strategy1)
        else:
            strategy1 = strategy
        self.mean = P.ReduceMean(keep_dims=True).shard((strategy1,))
        self.square = P.Square().shard((strategy1,))
        self.sqrt = P.Sqrt().shard((strategy1,))
        self.sub1 = P.Sub().shard((strategy1, strategy1))
        self.add = P.TensorAdd().shard((strategy1, ()))
        self.eps = eps
        self.real_div = P.RealDiv().shard((strategy1, strategy1))
        self.mul = P.Mul().shard((strategy, (1, 1, 1)))
        self.add2 = P.TensorAdd().shard((strategy, (1, 1, 1)))
        self.axes = axes
        self.is_gn = is_gn
        self.num_groups = num_groups
        # layer norm (1,1,-1) (-1,1,1)
        if num_dimension == 3:
            self.view_shape = (1, 1, -1)
        else:
            self.view_shape = (-1, 1, 1)

    def construct(self, x):
        r"""
        x : batch x seq_length x hidden_size
        """
        origin_shape = x.shape
        if self.is_gn:
            # Fold remaining axes so statistics run per (batch, group).
            x = x.view(origin_shape[0], self.num_groups, -1)
        mean = self.mean(x, self.axes)
        diff = self.sub1(x, mean)
        variance = self.mean(self.square(diff), self.axes)
        # NOTE: eps is added before sqrt, i.e. denominator = sqrt(var + eps).
        variance_eps = self.sqrt(self.add(variance, self.eps))
        output = self.real_div(diff, variance_eps)
        if self.is_gn:
            output = output.view(origin_shape)
        output = self.add2(self.mul(output, self.gamma.view(self.view_shape)), self.beta.view(self.view_shape))
        return output
class Mlp(nn.Cell):
    r"""
    MLP block: fc1 -> activation -> dropout -> fc2 -> dropout.

    Dense layers compute in float16; the activation runs in float32.
    `dp`/`mp` are data/model parallel shard factors for the primitives.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., dp=1, mp=1):
        super(Mlp, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Dense(in_features, hidden_features, weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16)
        self.fc1.matmul.shard(((dp, 1), (mp, 1)))
        self.fc1.bias_add.shard(((dp, mp), (mp,)))
        self.act = act_layer()
        self.act.gelu.shard(((dp, mp),))
        self.fc2 = nn.Dense(hidden_features, out_features,
                            weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16)
        self.fc2.matmul.shard(((dp, mp), (1, mp)))
        self.fc2.bias_add.shard(((dp, 1), (1,)))
        # MindSpore Dropout takes keep_prob, hence 1.0 - drop.
        self.drop = nn.Dropout(1.0-drop)
        self.drop.dropout.shard(((dp, 1),))
        self.drop2 = nn.Dropout(1.0-drop)
        self.drop2.dropout.shard(((dp, mp),))

    def construct(self, x):
        r"""
        x : fp32
        """
        origin_shape = x.shape
        # Collapse leading axes so the Dense layers see a 2-D input.
        x = x.view(-1, origin_shape[-1])
        x = self.fc1(F.cast(x, mindspore.float16))
        x = self.act(F.cast(x, mindspore.float32))
        x = self.drop2(x)
        x = self.fc2(F.cast(x, mindspore.float16))
        x = self.drop(F.cast(x, mindspore.float32))
        x = x.view(origin_shape[:-1]+(-1,))
        return x
class Attention(nn.Cell):
    """Multi-head self-attention with parallel shard annotations.

    Projections run in float16; softmax runs in float32 for stability.
    """
    def __init__(self, dim, hidden_dim=None,
                 num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., dp=1, mp=1):
        super(Attention, self).__init__()
        hidden_dim = hidden_dim or dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        head_dim = hidden_dim // num_heads
        self.head_dim = head_dim
        # Default 1/sqrt(head_dim) scaling; qk_scale is stored but not applied here.
        self.scale = head_dim ** -0.5
        self.qk_scale = qk_scale
        self.mul = P.Mul().shard(((dp, mp, 1, 1), ()))
        self.q = nn.Dense(dim, hidden_dim, has_bias=qkv_bias,
                          weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16)
        self.q.matmul.shard(((dp, 1), (mp, 1)))
        if qkv_bias:
            self.q.bias_add.shard(((dp, mp), (mp,)))
        self.k = nn.Dense(dim, hidden_dim, has_bias=qkv_bias,
                          weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16)
        self.k.matmul.shard(((dp, 1), (mp, 1)))
        if qkv_bias:
            self.k.bias_add.shard(((dp, mp), (mp,)))
        self.v = nn.Dense(dim, hidden_dim, has_bias=qkv_bias,
                          weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16)
        self.v.matmul.shard(((dp, 1), (mp, 1)))
        if qkv_bias:
            self.v.bias_add.shard(((dp, mp), (mp,)))
        self.softmax = nn.Softmax(axis=-1)
        self.softmax.softmax.shard(((dp, mp, 1, 1),))
        self.batmatmul_trans_b = P.BatchMatMul().shard(((dp, mp, 1, 1), (dp, mp, 1, 1)))
        self.attn_drop = nn.Dropout(1. - attn_drop)
        self.attn_drop.dropout.shard(((dp, mp, 1, 1),))
        self.proj = nn.Dense(hidden_dim, dim, weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16)
        self.proj.matmul.shard(((dp, mp), (1, mp)))
        self.proj.bias_add.shard(((dp, 1), (1,)))
        self.proj_drop = nn.Dropout(1. - proj_drop)
        self.proj_drop.dropout.shard(((dp, 1),))
        self.transpose = P.Transpose().shard(((dp, 1, mp, 1),))
        self.transpose2 = P.Transpose().shard(((dp, 1, 1, 1),))
        self.reshape = P.Reshape()

    def construct(self, x):
        """Multi-head Attention"""
        b_size, n_channel, _ = x.shape  # fp32
        x = F.cast(x, mindspore.float16)
        x = x.view(b_size*n_channel, -1)
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)
        # (B*N, heads*head_dim) -> (B, heads, N, head_dim)
        q = self.transpose(
            F.reshape(
                q,
                (-1, n_channel, self.num_heads, self.head_dim)),
            (0, 2, 1, 3))
        # k is laid out (B, heads, head_dim, N) so batmatmul(q, k) = q @ k^T.
        k = self.transpose(
            F.reshape(
                k, (-1, n_channel, self.num_heads, self.head_dim)),
            (0, 2, 3, 1))
        v = self.transpose(
            F.reshape(
                v,
                (-1, n_channel, self.num_heads, self.head_dim)),
            (0, 2, 1, 3))
        # Scale queries, then softmax in fp32.
        attn = self.softmax(F.cast(self.batmatmul_trans_b(self.mul(q, self.scale), k), mindspore.float32))
        attn = self.attn_drop(attn)
        x = self.reshape(self.transpose2(self.batmatmul_trans_b(F.cast(attn, mindspore.float16), v),
                                         (0, 2, 1, 3)), (b_size*n_channel, -1))
        x = self.proj(x)
        x = self.proj_drop(x)  # fp16
        return x.view(b_size, n_channel, -1)
class Block(nn.Cell):
    """Transformer encoder block: pre-norm attention + pre-norm MLP, each with
    a residual connection and optional stochastic-depth drop path."""
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 dp=1, mp=1):
        super(Block, self).__init__()
        self.norm1 = norm_layer([dim], epsilon=1e-6)
        self.norm1.layer_norm.shard(((dp, 1, 1), (1,), (1,)))
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop=attn_drop, proj_drop=drop,
                              dp=dp, mp=mp)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path, num_dimension=3, dp=dp) if drop_path > 0. else P.Identity()
        self.norm2 = norm_layer([dim], epsilon=1e-6)
        self.norm2.layer_norm.shard(((dp, 1, 1), (1,), (1,)))
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
                       dp=dp, mp=mp)
        self.add = P.Add().shard(((dp, 1, 1), (dp, 1, 1)))

    def construct(self, x):
        # x fp32
        x = self.add(x, self.drop_path(self.attn(self.norm1(x))))  # output x fp32
        x = self.add(x, self.drop_path(self.mlp(self.norm2(x))))  # output x fp32
        return x
class ConvBlock(nn.Cell):
    """Bottleneck conv block (1x1 -> 3x3 -> 1x1, expansion 4) with optional
    residual projection, drop path, and an optional weighted fusion of an
    external feature map `x_t` via a learnable sigmoid gate."""
    def __init__(self, inplanes, outplanes, stride=1,
                 res_conv=False, act_layer=nn.ReLU, groups=1,
                 norm_layer=nn.BatchNorm2d, drop_block=None,
                 drop_path=0., return_x_2=False, weighted_fusion=False, dp=1):
        super(ConvBlock, self).__init__()
        self.init_network(inplanes, outplanes, norm_layer,
                          act_layer, stride, groups, dp)
        self.add = P.Add().shard(((dp, 1, 1, 1), (dp, 1, 1, 1)))
        self.mul = P.Mul().shard(((1,), (dp, 1, 1, 1)))
        if res_conv:
            # 1x1 projection so the residual matches outplanes/stride.
            self.residual_conv = nn.Conv2d(inplanes, outplanes,
                                           kernel_size=1, stride=stride,
                                           padding=0, has_bias=False, pad_mode="pad",
                                           weight_init=HeNormal(mode='fan_out',
                                                                nonlinearity='relu')).to_float(mindspore.float16)
            self.residual_conv.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1)))
            self.residual_conv.bias_add.shard(((dp, 1, 1, 1), (1,)))
            self.residual_bn = norm_layer(outplanes, eps=1e-6)
            self.residual_bn.bn_train.shard(((dp, 1, 1, 1), (1,), (1,), (1,), (1,)))
        self.res_conv = res_conv
        self.drop_block = drop_block
        self.drop_path = DropPath(drop_path, dp=dp)
        self.return_x_2 = return_x_2
        self.weighted_fusion = weighted_fusion
        if weighted_fusion:
            # Ops for the sigmoid gate c = 1 / (1 + exp(-self.c)).
            self.add1 = P.Add().shard(((), (1,)))
            self.div = P.Div().shard(((), (1,)))
            self.exp = P.Exp().shard(((1,),))
            self.neg = P.Neg().shard(((1,),))
            self.c = Parameter(Tensor(np.zeros((1,)), mindspore.float16), requires_grad=True)

    def init_network(self, inplanes, outplanes, norm_layer,
                     act_layer, stride, groups, dp):
        """Build the three conv/bn/act stages of the bottleneck."""
        expansion = 4
        med_planes = outplanes // expansion
        self.conv1 = nn.Conv2d(inplanes, med_planes,
                               kernel_size=1, stride=1,
                               padding=0, has_bias=False, pad_mode="pad",
                               weight_init=HeNormal(mode='fan_out', nonlinearity='relu')).to_float(mindspore.float16)
        self.conv1.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1)))
        self.conv1.bias_add.shard(((dp, 1, 1, 1), (1,)))
        self.bn1 = norm_layer(med_planes, eps=1e-6)
        self.bn1.bn_train.shard(((dp, 1, 1, 1), (1,), (1,), (1,), (1,)))
        self.act1 = act_layer()
        self.act1.relu.shard(((dp, 1, 1, 1),))
        # Only the 3x3 conv carries the stride and grouping.
        self.conv2 = nn.Conv2d(med_planes, med_planes,
                               kernel_size=3, stride=stride, group=groups,
                               padding=1, has_bias=False, pad_mode="pad",
                               weight_init=HeNormal(mode='fan_out', nonlinearity='relu')).to_float(mindspore.float16)
        self.conv2.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1)))
        self.conv2.bias_add.shard(((dp, 1, 1, 1), (1,)))
        self.bn2 = norm_layer(med_planes, eps=1e-6)
        self.bn2.bn_train.shard(((dp, 1, 1, 1), (1,), (1,), (1,), (1,)))
        self.act2 = act_layer()
        self.act2.relu.shard(((dp, 1, 1, 1),))
        self.conv3 = nn.Conv2d(med_planes, outplanes,
                               kernel_size=1, stride=1,
                               padding=0, has_bias=False, pad_mode="pad",
                               weight_init=HeNormal(mode='fan_out', nonlinearity='relu')).to_float(mindspore.float16)
        self.conv3.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1)))
        self.conv3.bias_add.shard(((dp, 1, 1, 1), (1,)))
        self.bn3 = norm_layer(outplanes, eps=1e-6)
        self.bn3.bn_train.shard(((dp, 1, 1, 1), (1,), (1,), (1,), (1,)))
        self.act3 = act_layer()
        self.act3.relu.shard(((dp, 1, 1, 1),))

    def construct(self, x, x_t=None):
        """ConvBlock construct"""
        residual = x
        x = self.conv1(x)  # fp16
        # BatchNorm runs in fp32, convs in fp16, hence the cast pairs.
        x = self.bn1(F.cast(x, mindspore.float32))
        x = F.cast(x, mindspore.float16)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act1(x)  # fp16
        if x_t is None:
            x = self.conv2(x)
        else:
            if self.weighted_fusion:
                # Learnable convex combination of x and the transformer map x_t.
                c = self.div(1.0, self.add1(1.0, self.exp(self.neg(self.c))))
                x = self.conv2(self.add(self.mul(c, x), self.mul(1.0-c, F.cast(x_t, mindspore.float16))))
            else:
                x = self.conv2(self.add(x, F.cast(x_t, mindspore.float16)))
        x = self.bn2(F.cast(x, mindspore.float32))
        x = F.cast(x, mindspore.float16)
        if self.drop_block is not None:
            x = self.drop_block(x)
        # x2 is the mid-stage activation, optionally returned for FCUDown.
        x2 = self.act2(x)
        x = self.conv3(x2)
        x = self.bn3(F.cast(x, mindspore.float32))
        x = F.cast(x, mindspore.float16)
        if self.drop_block is not None:
            x = self.drop_block(x)
        if self.drop_path is not None:
            x = self.drop_path(x)
        if self.res_conv:
            residual = self.residual_conv(residual)
            residual = self.residual_bn(F.cast(residual, mindspore.float32))
            residual = F.cast(residual, mindspore.float16)
        x = self.add(x, residual)
        x = self.act3(x)
        if self.return_x_2:
            return x, x2
        return x
class FCUDown(nn.Cell):
    """ CNN feature maps -> Transformer patch embeddings
    """
    def __init__(self, inplanes, outplanes, dw_stride, act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, cls_token=True, dp=1):
        super(FCUDown, self).__init__()
        # Average-pool stride that downsamples the feature map to patch grid.
        self.dw_stride = dw_stride
        self.cls_token = cls_token
        self.conv_project = nn.Conv2d(inplanes, outplanes,
                                      kernel_size=1, stride=1,
                                      padding=0, has_bias=True, pad_mode="pad",
                                      weight_init=HeNormal(mode='fan_out',
                                                           nonlinearity='relu')).to_float(mindspore.float16)
        self.conv_project.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1)))
        self.conv_project.bias_add.shard(((dp, 1, 1, 1), (1,)))
        self.sample_pooling = nn.AvgPool2d(kernel_size=dw_stride, stride=dw_stride)
        self.sample_pooling.avg_pool.shard(((dp, 1, 1, 1),))
        self.ln = norm_layer([outplanes], epsilon=1e-6)
        self.ln.layer_norm.shard(((dp, 1, 1), (1,), (1,)))
        self.act = act_layer()
        self.act.gelu.shard(((dp, 1, 1),))
        self.concat = P.Concat(axis=1).shard(((dp, 1, 1), (dp, 1, 1)))
        self.transpose = P.Transpose().shard(((dp, 1, 1),))
        self.slice = P.StridedSlice().shard(((dp, 1, 1),))

    def construct(self, x, x_t):
        """FCUDown construct"""
        # x fp16, x_t fp32
        x = self.conv_project(x)  # [N, C, H, W]
        tmp = self.sample_pooling(x)
        # [N, C, H', W'] -> [N, C, H'*W'] -> [N, H'*W', C]
        tmp1 = flatten(tmp, 2)
        x = self.transpose(tmp1, (0, 2, 1))
        x = self.ln(F.cast(x, mindspore.float32))
        x = self.act(x)
        if self.cls_token:
            # Prepend the class token taken from the transformer branch x_t.
            b_size, _, height = F.shape(x_t)
            tmp2 = self.slice(x_t, (0, 0, 0), (b_size, 1, height), (1, 1, 1))
            x = self.concat([tmp2, x])
        return x
class FCUUp(nn.Cell):
    """ Transformer patch embeddings -> CNN feature maps
    """
    def __init__(self, inplanes, outplanes, up_stride, act_layer=nn.ReLU,
                 norm_layer=nn.BatchNorm2d, cls_token=True, seq_length=196, dp=1):
        super(FCUUp, self).__init__()
        # Nearest-neighbor upsampling factor back to CNN resolution.
        self.up_stride = up_stride
        self.conv_project = nn.Conv2d(inplanes, outplanes,
                                      kernel_size=1, stride=1,
                                      padding=0, has_bias=True, pad_mode="pad",
                                      weight_init=HeNormal(mode='fan_out',
                                                           nonlinearity='relu')).to_float(mindspore.float16)
        self.conv_project.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1)))
        self.conv_project.bias_add.shard(((dp, 1, 1, 1), (1,)))
        # Non-affine layer norm over the embedding axis.
        self.ln = Norm(inplanes, axes=-1, affine=False, dp=dp, eps=1e-6)
        self.bn = norm_layer(outplanes, eps=1e-6)
        self.bn.bn_train.shard(((dp, 1, 1, 1), (1,), (1,), (1,), (1,)))
        self.act = act_layer()
        self.act.relu.shard(((dp, 1, 1, 1),))
        self.cls_token = cls_token
        # Assumes a square patch grid: height == weight == sqrt(seq_length).
        height = weight = int(math.sqrt(seq_length))
        self.resize_neighbor = P.ResizeNearestNeighbor(size=(height * self.up_stride,
                                                             weight * self.up_stride)).shard(((dp, 1, 1, 1),))
        self.reshape = P.Reshape()
        self.transpose = P.Transpose().shard(((dp, 1, 1),))
        self.slice = P.StridedSlice().shard(((dp, 1, 1),))

    def construct(self, x, height, weight):
        """FCUUp construct"""
        # x fp32
        b_size, t_num, channel = F.shape(x)
        x = self.ln(x)
        if self.cls_token:
            # Drop the class token (index 0) before reshaping to a feature map.
            x_r = self.reshape(self.transpose(\
                self.slice(x, (0, 1, 0), (b_size, t_num, channel),\
                           (1, 1, 1)), (0, 2, 1)), (b_size, channel, height, weight))
        else:
            x_r = self.reshape(self.transpose(x, (0, 2, 1)), (b_size, channel, height, weight))
        # x_r fp32
        x_r_fp32 = F.cast(self.conv_project(F.cast(x_r, mindspore.float16)), mindspore.float32)
        x_r_fp16 = F.cast(self.bn(x_r_fp32), mindspore.float16)
        x_r = self.act(x_r_fp16)
        return self.resize_neighbor(x_r)
class ConvTransBlock(nn.Cell):
    """
    Basic module for ConvTransformer, keep feature maps for CNN block and patch embeddings for transformer encoder block
    """
    def __init__(self, inplanes, outplanes, res_conv, stride, dw_stride, embed_dim, num_heads=12, mlp_ratio=4.,
                 qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 last_fusion=False, groups=1, cls_token=True, weighted_fusion=False, dp=1, mp=1, seq_length=196):
        super(ConvTransBlock, self).__init__()
        expansion = 4
        self.cnn_block = ConvBlock(inplanes=inplanes, outplanes=outplanes, res_conv=res_conv, stride=stride,
                                   groups=groups, drop_path=drop_path_rate, return_x_2=True, dp=dp)
        if last_fusion:
            # Final stage downsamples (stride 2) and projects the residual.
            self.fusion_block = ConvBlock(inplanes=outplanes, outplanes=outplanes, stride=2, res_conv=True,
                                          groups=groups, drop_path=drop_path_rate, weighted_fusion=weighted_fusion,
                                          dp=dp)
        else:
            self.fusion_block = ConvBlock(inplanes=outplanes, outplanes=outplanes,
                                          groups=groups, drop_path=drop_path_rate, weighted_fusion=weighted_fusion,
                                          dp=dp)
        # CNN -> transformer bridge (downsample + embed).
        self.squeeze_block = FCUDown(inplanes=outplanes // expansion,
                                     outplanes=embed_dim, dw_stride=dw_stride, cls_token=cls_token,
                                     dp=dp)
        # Transformer -> CNN bridge (project + upsample).
        self.expand_block = FCUUp(inplanes=embed_dim,
                                  outplanes=outplanes // expansion, up_stride=dw_stride, cls_token=cls_token,
                                  dp=dp, seq_length=seq_length)
        self.trans_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop_rate, attn_drop=attn_drop_rate,
                                 drop_path=drop_path_rate, dp=dp, mp=mp)
        self.dw_stride = dw_stride
        self.embed_dim = embed_dim
        self.last_fusion = last_fusion
        self.weighted_fusion = weighted_fusion
        if weighted_fusion:
            # Ops for the sigmoid gate c = 1 / (1 + exp(-self.c)).
            self.exp = P.Exp().shard(((1,),))
            self.c = Parameter(Tensor(np.zeros((1,)), mindspore.float16), requires_grad=True)
        self.add = P.Add().shard(((dp, 1, 1), (dp, 1, 1)))
        self.add1 = P.Add().shard(((), (1,)))
        self.div = P.Div().shard(((), (1,)))
        self.mul = P.Mul().shard(((1,), (dp, 1, 1)))
        self.sub = P.Sub().shard(((), (1,)))
        self.neg = P.Neg().shard(((1,),))

    def construct(self, x, x_t):
        """ConvTransBlock construct"""
        # x fp16, x_t fp32
        x, x2 = self.cnn_block(x)  # both fp16
        _, _, height, weight = x2.shape
        x_st = self.squeeze_block(x2, x_t)  # x_st fp32
        if self.weighted_fusion:
            # Gate the embedded CNN features against the transformer stream.
            c = self.div(1.0, self.add1(1.0, self.exp(self.neg(self.c))))
            x_t = self.trans_block(self.add(self.mul(c, x_st), self.mul(self.sub(1.0, c), x_t)))
        else:
            x_t = self.trans_block(self.add(x_st, x_t))  # x_t fp32
        x_t_r = self.expand_block(x_t, height // self.dw_stride, weight // self.dw_stride)  # x_t_r fp16
        x = self.fusion_block(x, x_t_r)
        return x, x_t
  531. class ConformerOverflow(nn.Cell):
  532. """Conformeroverflow"""
  533. def __init__(self, patch_size=16, in_chans=3, num_classes=1000,
  534. base_channel=64, channel_ratio=4, embed_dim=768,
  535. stage_point=None, num_heads=12, mlp_ratio=4.,
  536. qkv_bias=False, qk_scale=None, drop_rate=0.,
  537. attn_drop_rate=0., drop_path_rate=0., cls_token=True,
  538. batch_size=8, weighted_fusion=False, dp=1, mp=1, seq_length=196):
  539. # Transformer
  540. super(ConformerOverflow, self).__init__()
  541. self.num_classes = num_classes
  542. self.num_features = self.embed_dim = embed_dim
  543. depth = stage_point[-1]
  544. self.cls_token_flag = cls_token
  545. if self.cls_token_flag:
  546. self.cls_token = mindspore.Parameter(initializer('zeros', (1, 1, embed_dim), mindspore.float32))
  547. self.trans_dpr = [Tensor(x, mindspore.float32) for x in np.linspace(0, drop_path_rate, depth, dtype=np.float32)]
  548. # Classifier head
  549. self.trans_norm = nn.LayerNorm([embed_dim], epsilon=1e-05)
  550. self.trans_norm.layer_norm.shard(((dp, 1, 1), (1,), (1,)))
  551. self.trans_cls_head = nn.Dense(embed_dim, num_classes,
  552. weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16)
  553. self.trans_cls_head.matmul.shard(((dp, 1), (1, 1)))
  554. self.trans_cls_head.bias_add.shard(((dp, 1), (1,)))
  555. self.pooling = nn.AvgPool2d(kernel_size=7, stride=7)
  556. self.pooling.avg_pool.shard(((dp, 1, 1, 1),))
  557. self.conv_cls_head = nn.Dense(int(256 * channel_ratio), num_classes,
  558. weight_init=TruncatedNormal(0.02)).to_float(mindspore.float16)
  559. self.conv_cls_head.matmul.shard(((dp, 1), (1, 1)))
  560. self.conv_cls_head.bias_add.shard(((dp, 1), (1,)))
  561. # Stem stage: get the feature maps by conv block (copied form resnet.py)
  562. self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2,
  563. padding=3, has_bias=False, pad_mode="pad",
  564. weight_init=HeNormal(mode='fan_out',
  565. nonlinearity='relu')).to_float(mindspore.float16)
  566. self.conv1.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1)))
  567. self.conv1.bias_add.shard(((dp, 1, 1, 1), (1,)))
  568. self.bn1 = nn.BatchNorm2d(64)
  569. self.bn1.bn_train.shard(((dp, 1, 1, 1), (1,), (1,), (1,), (1,)))
  570. self.act1 = nn.ReLU()
  571. self.act1.relu.shard(((dp, 1, 1, 1),))
  572. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
  573. self.maxpool.max_pool.shard(((dp, 1, 1, 1),))
  574. self.concat = P.Concat(axis=1).shard(((dp, 1, 1), (dp, 1, 1)))
  575. self.conv_trans_list = []
  576. self.broadcastto = P.BroadcastTo((batch_size, -1, -1)).shard(((1, 1, 1),))
  577. self.slice = P.StridedSlice().shard(((dp, 1, 1),))
  578. self.squeeze = P.Squeeze(1).shard(((dp, 1, 1),))
  579. self.mean = P.ReduceMean(keep_dims=True).shard(((dp, 1, 1),))
  580. self.trunc_normal_ = Init.TruncatedNormal(.02)
  581. if self.cls_token_flag:
  582. self.trunc_normal_(self.cls_token.asnumpy())
  583. self.init_stage1_4(base_channel, channel_ratio, patch_size, embed_dim, num_heads,
  584. mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate,
  585. stage_point, weighted_fusion, seq_length, dp, mp)
  586. self.init_stage5_12(base_channel, channel_ratio, patch_size, embed_dim, num_heads,
  587. mlp_ratio, qkv_bias, qk_scale, drop_rate, attn_drop_rate,
  588. stage_point, weighted_fusion, seq_length, depth, dp, mp)
  589. def init_stage1_4(self, base_channel, channel_ratio, patch_size,
  590. embed_dim, num_heads, mlp_ratio, qkv_bias, qk_scale,
  591. drop_rate, attn_drop_rate, stage_point, weighted_fusion,
  592. seq_length, dp, mp):
  593. # 1 stage
  594. stage_1_channel = int(base_channel * channel_ratio)
  595. trans_dw_stride = patch_size // 4
  596. self.conv_1 = ConvBlock(inplanes=64, outplanes=stage_1_channel, res_conv=True, stride=1, dp=dp)
  597. self.trans_patch_conv = nn.Conv2d(64, embed_dim,
  598. kernel_size=trans_dw_stride, stride=trans_dw_stride,
  599. padding=0, has_bias=True, pad_mode="pad",
  600. weight_init=HeNormal(mode='fan_out',
  601. nonlinearity='relu')).to_float(mindspore.float16)
  602. self.trans_patch_conv.conv2d.shard(((dp, 1, 1, 1), (1, 1, 1, 1)))
  603. self.trans_patch_conv.bias_add.shard(((dp, 1, 1, 1), (1,)))
  604. self.trans_1 = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
  605. qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=self.trans_dpr[0],
  606. dp=dp, mp=mp)
  607. # 2~4 stage
  608. init_stage = 2
  609. fin_stage = stage_point[1] + 1 # fin_stage = depth // 3 + 1
  610. for i in range(init_stage, fin_stage):
  611. self.conv_trans_list.append(
  612. ConvTransBlock(stage_1_channel, stage_1_channel, False, 1,
  613. dw_stride=trans_dw_stride, embed_dim=embed_dim, num_heads=num_heads,
  614. mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
  615. drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
  616. drop_path_rate=self.trans_dpr[i - 1], cls_token=self.cls_token_flag,
  617. weighted_fusion=weighted_fusion, dp=dp, mp=mp, seq_length=seq_length)
  618. )
  619. def init_stage5_12(self, base_channel, channel_ratio, patch_size,
  620. embed_dim, num_heads, mlp_ratio, qkv_bias, qk_scale,
  621. drop_rate, attn_drop_rate, stage_point, weighted_fusion,
  622. seq_length, depth, dp, mp):
  623. stage_1_channel = int(base_channel * channel_ratio)
  624. stage_2_channel = int(base_channel * channel_ratio * 2)
  625. trans_dw_stride = patch_size // 4
  626. fin_stage = stage_point[1] + 1
  627. # 5~8 stage
  628. init_stage = fin_stage # 5
  629. fin_stage = stage_point[2] + 1 # fin_stage = fin_stage + depth // 3 # 9
  630. for i in range(init_stage, fin_stage):
  631. s = 2 if i == init_stage else 1
  632. in_channel = stage_1_channel if i == init_stage else stage_2_channel
  633. res_conv = bool(i == init_stage)
  634. self.conv_trans_list.append(
  635. ConvTransBlock(in_channel, stage_2_channel, res_conv, s, dw_stride=trans_dw_stride // 2,
  636. embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
  637. qkv_bias=qkv_bias, qk_scale=qk_scale, drop_rate=drop_rate,
  638. attn_drop_rate=attn_drop_rate, drop_path_rate=self.trans_dpr[i - 1],
  639. cls_token=self.cls_token_flag, weighted_fusion=weighted_fusion,
  640. dp=dp, mp=mp, seq_length=seq_length)
  641. )
  642. stage_3_channel = int(base_channel * channel_ratio * 2 * 2)
  643. # 9~12 stage
  644. init_stage = fin_stage # 9
  645. fin_stage = stage_point[3] + 1 # fin_stage = fin_stage + depth // 3 # 13
  646. for i in range(init_stage, fin_stage):
  647. s = 2 if i == init_stage else 1
  648. in_channel = stage_2_channel if i == init_stage else stage_3_channel
  649. res_conv = bool(i == init_stage)
  650. last_fusion = bool(i == depth)
  651. self.conv_trans_list.append(
  652. ConvTransBlock(
  653. in_channel, stage_3_channel, res_conv, s, dw_stride=trans_dw_stride // 4,
  654. embed_dim=embed_dim,
  655. num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
  656. drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
  657. drop_path_rate=self.trans_dpr[i - 1],
  658. last_fusion=last_fusion,
  659. cls_token=self.cls_token_flag,
  660. weighted_fusion=weighted_fusion,
  661. dp=dp, mp=mp,
  662. seq_length=seq_length
  663. )
  664. )
  665. self.conv_trans_blks = nn.CellList(self.conv_trans_list)
  666. def construct(self, x):
  667. """conformer construct"""
  668. # x fp32
  669. cls_tokens = None
  670. if self.cls_token_flag:
  671. cls_tokens = self.broadcastto(self.cls_token) # fp32
  672. # stem stage [N, 3, 224, 224] -> [N, 64, 56, 56]
  673. x_fp32 = F.cast(self.conv1(F.cast(x, mindspore.float16)), mindspore.float32)
  674. x_fp16 = F.cast(self.bn1(x_fp32), mindspore.float16)
  675. x_base = self.maxpool(self.act1(x_fp16)) # fp16
  676. # 1 stage
  677. x = self.conv_1(x_base) # fp16
  678. tmp = self.trans_patch_conv(x_base)
  679. tmp1 = flatten(tmp, 2)
  680. x_t = F.cast(tmp1.transpose((0, 2, 1)), mindspore.float32) # fp32
  681. if self.cls_token_flag:
  682. x_t = self.concat([cls_tokens, x_t])
  683. x_t = self.trans_1(x_t) # fp32
  684. # 2 ~ final
  685. for blk in self.conv_trans_blks:
  686. x, x_t = blk(x, x_t) # x fp16, x_t fp32
  687. # conv classification
  688. tmp2 = self.pooling(x)
  689. x_p = flatten(tmp2, 1)
  690. conv_cls = self.conv_cls_head(x_p) # conv_cls fp16
  691. # trans classification
  692. x_t = self.trans_norm(x_t)
  693. x_t = F.cast(x_t, mindspore.float16)
  694. b_size, _, height = F.shape(x_t)
  695. tmp3 = self.squeeze(self.slice(x_t, (0, 0, 0), (b_size, 1, height), (1, 1, 1)))
  696. if self.cls_token_flag:
  697. tran_cls = self.trans_cls_head(tmp3)
  698. else:
  699. tran_cls = self.trans_cls_head(self.mean(x_t, 1))
  700. return [conv_cls, tran_cls]
  701. @pytest.mark.level0
  702. @pytest.mark.platform_arm_ascend_training
  703. @pytest.mark.platform_x86_ascend_training
  704. @pytest.mark.env_onecard
  705. def test_conformer_arm_ascend():
  706. """
  707. Feature: test conformer architecture
  708. Description: convolution and transformer
  709. Expectation: compile success
  710. """
  711. context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
  712. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=1, global_rank=0)
  713. net = ConformerOverflow(patch_size=16, channel_ratio=4, embed_dim=384, stage_point=[1, 4, 8, 12],
  714. num_heads=6, mlp_ratio=4, qkv_bias=False, qk_scale=None, cls_token=True,
  715. num_classes=1000, drop_rate=0.0, drop_path_rate=0.1, attn_drop_rate=0.0,
  716. batch_size=32, weighted_fusion=True, dp=8, mp=1, seq_length=196)
  717. ls = CrossEntropySmooth(reduction="mean")
  718. net_with_loss_net = NetWithLossCell(net, ls)
  719. net_with_loss = _VirtualDatasetCell(net_with_loss_net)
  720. optimizer = nn.AdamWeightDecay(net.trainable_params())
  721. train_net = nn.TrainOneStepCell(net_with_loss, optimizer)
  722. data = Tensor(np.ones([32, 3, 224, 224]), dtype=mindspore.float32)
  723. label = Tensor(np.ones([32]).astype(np.int32))
  724. label = one_hot_int(label, 1000)
  725. train_net(data, label)