You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

effnet.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. """EffNet model define"""
  2. import numpy as np
  3. import mindspore.nn as nn
  4. from mindspore.ops import operations as P
  5. from mindspore.common.initializer import TruncatedNormal
  6. from mindspore import Tensor
  7. __all__ = ['effnet']
  8. def weight_variable():
  9. """weight initial"""
  10. return TruncatedNormal(0.02)
  11. def _make_divisible(v, divisor, min_value=None):
  12. """
  13. This function is taken from the original tf repo.
  14. It ensures that all layers have a channel number that is divisible by 8
  15. It can be seen here:
  16. https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
  17. :param v:
  18. :param divisor:
  19. iparam min_value:
  20. :return:
  21. """
  22. if min_value is None:
  23. min_value = divisor
  24. new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  25. # Make sure that round down does not go down by more than 10%.
  26. if new_v < 0.9 * v:
  27. new_v += divisor
  28. return new_v
  29. class Swish(nn.Cell):
  30. def __init__(self):
  31. super().__init__(Swish)
  32. self.sigmoid = nn.Sigmoid()
  33. def construct(self, x):
  34. s = self.sigmoid(x)
  35. m = x*s
  36. return m
  37. class AdaptiveAvgPool(nn.Cell):
  38. def __init__(self, output_size=None):
  39. super().__init__(AdaptiveAvgPool)
  40. self.mean = P.ReduceMean(keep_dims=True)
  41. self.output_size = output_size
  42. def construct(self, x):
  43. return self.mean(x, (2, 3))
  44. class SELayer(nn.Cell):
  45. """
  46. SELayer
  47. """
  48. def __init__(self, channel, reduction=4):
  49. super().__init__(SELayer)
  50. reduced_chs = _make_divisible(channel/reduction, 1)
  51. self.avg_pool = AdaptiveAvgPool(output_size=(1, 1))
  52. weight = weight_variable()
  53. self.conv_reduce = nn.Conv2d(
  54. in_channels=channel, out_channels=reduced_chs, kernel_size=1, has_bias=True, weight_init=weight)
  55. self.act1 = Swish()
  56. self.conv_expand = nn.Conv2d(
  57. in_channels=reduced_chs, out_channels=channel, kernel_size=1, has_bias=True)
  58. self.act2 = nn.Sigmoid()
  59. def construct(self, x):
  60. o = self.avg_pool(x)
  61. o = self.conv_reduce(o)
  62. o = self.act1(o)
  63. o = self.conv_expand(o)
  64. o = self.act2(o)
  65. return x * o
  66. class DepthwiseSeparableConv(nn.Cell):
  67. """
  68. DepthwiseSeparableConv
  69. """
  70. def __init__(self, in_chs, out_chs, dw_kernel_size=3,
  71. stride=1, noskip=False, se_ratio=0.0, drop_connect_rate=0.0):
  72. super().__init__(DepthwiseSeparableConv)
  73. assert stride in [1, 2]
  74. self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
  75. self.drop_connect_rate = drop_connect_rate
  76. self.conv_dw = nn.Conv2d(in_channels=in_chs, out_channels=in_chs, kernel_size=dw_kernel_size,
  77. stride=stride, pad_mode="pad", padding=1, has_bias=False, group=in_chs)
  78. self.bn1 = nn.BatchNorm2d(in_chs, eps=0.001)
  79. self.act1 = Swish()
  80. if se_ratio is not None and se_ratio > 0.:
  81. self.se = SELayer(in_chs, reduction=se_ratio)
  82. else:
  83. print("ERRRRRORRRR -- not prepared for this one\n")
  84. self.conv_pw = nn.Conv2d(
  85. in_channels=in_chs, out_channels=out_chs, kernel_size=1, stride=stride, has_bias=False)
  86. self.bn2 = nn.BatchNorm2d(out_chs, eps=0.001)
  87. def construct(self, x):
  88. """
  89. construct
  90. """
  91. residual = x
  92. x = self.conv_dw(x)
  93. x = self.bn1(x)
  94. x = self.act1(x)
  95. x = self.se(x)
  96. x = self.conv_pw(x)
  97. x = self.bn2(x)
  98. if self.has_residual:
  99. x += residual
  100. return x
  101. def conv_3x3_bn(inp, oup, stride):
  102. weight = weight_variable()
  103. return nn.SequentialCell([
  104. nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=3, stride=stride,
  105. padding=1, weight_init=weight, has_bias=False, pad_mode='pad'),
  106. nn.BatchNorm2d(oup, eps=0.001), # , momentum=0.1),
  107. nn.HSwish()])
  108. def conv_1x1_bn(inp, oup):
  109. weight = weight_variable()
  110. return nn.SequentialCell([
  111. nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=1,
  112. stride=1, padding=0, weight_init=weight, has_bias=False),
  113. nn.BatchNorm2d(oup, eps=0.001),
  114. nn.HSwish()])
  115. class InvertedResidual(nn.Cell):
  116. """
  117. InvertedResidual
  118. """
  119. def __init__(self, in_chs, out_chs, kernel_size, stride, padding, expansion, se_ratio):
  120. super().__init__(InvertedResidual)
  121. assert stride in [1, 2]
  122. mid_chs: int = _make_divisible(in_chs * expansion, 1)
  123. self.has_residual = (in_chs == out_chs and stride == 1)
  124. self.drop_connect_rate = 0
  125. self.conv_pw = nn.Conv2d(
  126. in_channels=in_chs, out_channels=mid_chs, kernel_size=1, stride=1, has_bias=False)
  127. self.bn1 = nn.BatchNorm2d(mid_chs, eps=0.001) # ,momentum=0.1)
  128. self.act1 = Swish()
  129. if stride > 1:
  130. self.conv_dw = nn.Conv2d(in_channels=mid_chs, out_channels=mid_chs, kernel_size=kernel_size,
  131. stride=stride, padding=padding, has_bias=False, group=mid_chs, pad_mode='same')
  132. else:
  133. self.conv_dw = nn.Conv2d(in_channels=mid_chs, out_channels=mid_chs, kernel_size=kernel_size,
  134. stride=stride, padding=padding, has_bias=False, group=mid_chs, pad_mode='pad')
  135. self.bn2 = nn.BatchNorm2d(mid_chs, eps=0.001) # ,momentum=0.1)
  136. self.act2 = Swish()
  137. # Squeeze-and-excitation
  138. if se_ratio is not None and se_ratio > 0.:
  139. self.se = SELayer(mid_chs, reduction=se_ratio)
  140. else:
  141. print("ERRRRRORRRR -- not prepared for this one\n")
  142. # Point-wise linear projection
  143. self.conv_pwl = nn.Conv2d(
  144. in_channels=mid_chs, out_channels=out_chs, kernel_size=1, stride=1, has_bias=False)
  145. self.bn3 = nn.BatchNorm2d(out_chs, eps=0.001) # ,momentum=0.1)
  146. def construct(self, x):
  147. """
  148. construct
  149. """
  150. residual = x
  151. # Point-wise expansion
  152. x = self.conv_pw(x)
  153. x = self.bn1(x)
  154. x = self.act1(x)
  155. # Depth-wise convolution
  156. x = self.conv_dw(x)
  157. x = self.bn2(x)
  158. x = self.act2(x)
  159. # Squeeze-and-excitation
  160. x = self.se(x)
  161. # Point-wise linear projection
  162. x = self.conv_pwl(x)
  163. x = self.bn3(x)
  164. if self.has_residual:
  165. x += residual
  166. return x
  167. class EfficientNet(nn.Cell):
  168. """
  169. EfficientNet
  170. """
  171. def __init__(self, cfgs, num_classes=1000):
  172. super().__init__(EfficientNet)
  173. # setting of inverted residual blocks
  174. self.cfgs = cfgs
  175. stem_size = 32
  176. self.num_classes_ = num_classes
  177. self.num_features_ = 1280
  178. self.conv_stem = nn.Conv2d(
  179. in_channels=3, out_channels=stem_size, kernel_size=3, stride=2, has_bias=False)
  180. self.bn1 = nn.BatchNorm2d(stem_size, eps=0.001) # momentum=0.1)
  181. self.act1 = Swish()
  182. in_chs = stem_size
  183. layers = [nn.SequentialCell([DepthwiseSeparableConv(in_chs, 16, 3, 1, se_ratio=4)]),
  184. nn.SequentialCell([InvertedResidual(16, 24, 3, 2, 0, 6, se_ratio=24),
  185. InvertedResidual(24, 24, 3, 1, 1, 6, se_ratio=24)]),
  186. nn.SequentialCell([InvertedResidual(24, 40, 5, 2, 0, 6, se_ratio=24),
  187. InvertedResidual(40, 40, 5, 1, 2, 6, se_ratio=24)]),
  188. nn.SequentialCell([InvertedResidual(40, 80, 3, 2, 0, 6, se_ratio=24),
  189. InvertedResidual(
  190. 80, 80, 3, 1, 1, 6, se_ratio=24),
  191. InvertedResidual(80, 80, 3, 1, 1, 6, se_ratio=24)]),
  192. nn.SequentialCell([InvertedResidual(80, 112, 5, 1, 2, 6, se_ratio=24),
  193. InvertedResidual(
  194. 112, 112, 5, 1, 2, 6, se_ratio=24),
  195. InvertedResidual(112, 112, 5, 1, 2, 6, se_ratio=24)]),
  196. nn.SequentialCell([InvertedResidual(112, 192, 5, 2, 0, 6, se_ratio=24),
  197. InvertedResidual(
  198. 192, 192, 5, 1, 2, 6, se_ratio=24),
  199. InvertedResidual(
  200. 192, 192, 5, 1, 2, 6, se_ratio=24),
  201. InvertedResidual(192, 192, 5, 1, 2, 6, se_ratio=24)]),
  202. nn.SequentialCell(
  203. [InvertedResidual(192, 320, 3, 1, 1, 6, se_ratio=24)])
  204. ]
  205. self.blocks = nn.SequentialCell(layers)
  206. self.conv_head = nn.Conv2d(
  207. in_channels=320, out_channels=self.num_features_, kernel_size=1)
  208. self.bn2 = nn.BatchNorm2d(self.num_features_, eps=0.001)
  209. self.act2 = Swish()
  210. self.global_pool = AdaptiveAvgPool(output_size=(1, 1))
  211. self.classifier = nn.Dense(self.num_features_, num_classes)
  212. self._initialize_weights()
  213. def construct(self, x):
  214. """
  215. construct
  216. """
  217. x = self.conv_stem(x)
  218. x = self.bn1(x)
  219. x = self.act1(x)
  220. x = self.blocks(x)
  221. x = self.conv_head(x)
  222. x = self.bn2(x)
  223. x = self.act2(x)
  224. x = self.global_pool(x)
  225. x = P.Reshape()(x, (-1, self.num_features_))
  226. x = self.classifier(x)
  227. return x
  228. def _initialize_weights(self):
  229. """
  230. _initialize_weights
  231. """
  232. def init_linear_weight(m):
  233. m.weight.set_data(Tensor(np.random.normal(
  234. 0, 0.01, m.weight.data.shape).astype("float32")))
  235. if m.bias is not None:
  236. m.bias.set_data(
  237. Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
  238. for m in self.cells():
  239. if isinstance(m, nn.Conv2d):
  240. n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
  241. m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
  242. m.weight.data.shape).astype("float32")))
  243. if m.bias is not None:
  244. m.bias.data.zero_()
  245. m.weight.requires_grad = True
  246. elif isinstance(m, nn.BatchNorm2d):
  247. m.gamma.set_data(
  248. Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
  249. m.beta.set_data(
  250. Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
  251. elif isinstance(m, nn.Dense):
  252. init_linear_weight(m)
  253. def effnet(**kwargs):
  254. """
  255. Constructs a EfficientNet model
  256. """
  257. cfgs = [
  258. # k, t, c, SE, HS, s
  259. [3, 1, 16, 1, 0, 2],
  260. [3, 4.5, 24, 0, 0, 2],
  261. [3, 3.67, 24, 0, 0, 1],
  262. [5, 4, 40, 1, 1, 2],
  263. [5, 6, 40, 1, 1, 1],
  264. [5, 6, 40, 1, 1, 1],
  265. [5, 3, 48, 1, 1, 1],
  266. [5, 3, 48, 1, 1, 1],
  267. [5, 6, 96, 1, 1, 2],
  268. [5, 6, 96, 1, 1, 1],
  269. [5, 6, 96, 1, 1, 1],
  270. ]
  271. return EfficientNet(cfgs, **kwargs)