
ghostnet.py 16 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """GhostNet model define"""
  16. from functools import partial
  17. import math
  18. import numpy as np
  19. import mindspore.nn as nn
  20. from mindspore.ops import operations as P
  21. from mindspore import Tensor
  22. __all__ = ['ghostnet']
  23. def _make_divisible(x, divisor=4):
  24. return int(np.ceil(x * 1. / divisor) * divisor)
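# Note: _make_divisible rounds channel counts *up* to a multiple of `divisor`;
# e.g. _make_divisible(10) == 12 with the default divisor of 4.
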
class MyHSigmoid(nn.Cell):
    """
    Hard sigmoid definition.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> MyHSigmoid()
    """

    def __init__(self):
        super(MyHSigmoid, self).__init__()
        self.relu6 = nn.ReLU6()

    def construct(self, x):
        return self.relu6(x + 3.) * 0.16666667

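# Note: MyHSigmoid implements the hard sigmoid ReLU6(x + 3) / 6; the constant
# 0.16666667 approximates 1/6.
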
class Activation(nn.Cell):
    """
    Activation definition.

    Args:
        act_func (string): activation name.

    Returns:
        Tensor, output tensor.
    """

    def __init__(self, act_func):
        super(Activation, self).__init__()
        if act_func == 'relu':
            self.act = nn.ReLU()
        elif act_func == 'relu6':
            self.act = nn.ReLU6()
        elif act_func in ('hsigmoid', 'hard_sigmoid'):
            self.act = MyHSigmoid()
        elif act_func in ('hswish', 'hard_swish'):
            self.act = nn.HSwish()
        else:
            raise NotImplementedError

    def construct(self, x):
        return self.act(x)

class GlobalAvgPooling(nn.Cell):
    """
    Global average pooling definition.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> GlobalAvgPooling()
    """

    def __init__(self, keep_dims=False):
        super(GlobalAvgPooling, self).__init__()
        self.mean = P.ReduceMean(keep_dims=keep_dims)

    def construct(self, x):
        x = self.mean(x, (2, 3))
        return x

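# Note: GlobalAvgPooling averages over the spatial axes (2, 3) of an NCHW
# tensor, reducing (N, C, H, W) to (N, C), or to (N, C, 1, 1) with
# keep_dims=True.
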
class SE(nn.Cell):
    """
    SE (squeeze-and-excitation) wrapper definition.

    Args:
        num_out (int): Output channel.
        ratio (int): Reduction ratio for the middle channel.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> SE(4)
    """

    def __init__(self, num_out, ratio=4):
        super(SE, self).__init__()
        num_mid = _make_divisible(num_out // ratio)
        self.pool = GlobalAvgPooling(keep_dims=True)
        self.conv_reduce = nn.Conv2d(in_channels=num_out, out_channels=num_mid,
                                     kernel_size=1, has_bias=True, pad_mode='pad')
        self.act1 = Activation('relu')
        self.conv_expand = nn.Conv2d(in_channels=num_mid, out_channels=num_out,
                                     kernel_size=1, has_bias=True, pad_mode='pad')
        self.act2 = Activation('hsigmoid')
        self.mul = P.Mul()

    def construct(self, x):
        out = self.pool(x)
        out = self.conv_reduce(out)
        out = self.act1(out)
        out = self.conv_expand(out)
        out = self.act2(out)
        out = self.mul(x, out)
        return out

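# Note: SE follows the standard squeeze-and-excitation pattern: global average
# pool -> 1x1 reduce -> ReLU -> 1x1 expand -> hard sigmoid, producing
# per-channel gates that rescale the input feature map.
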
class ConvUnit(nn.Cell):
    """
    ConvUnit wrapper definition.

    Args:
        num_in (int): Input channel.
        num_out (int): Output channel.
        kernel_size (int): Kernel size.
        stride (int): Stride size.
        padding (int): Padding number.
        num_groups (int): Number of convolution groups.
        use_act (bool): Whether to apply the activation.
        act_type (string): Activation type.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ConvUnit(3, 3)
    """

    def __init__(self, num_in, num_out, kernel_size=1, stride=1, padding=0, num_groups=1,
                 use_act=True, act_type='relu'):
        super(ConvUnit, self).__init__()
        self.conv = nn.Conv2d(in_channels=num_in,
                              out_channels=num_out,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              group=num_groups,
                              has_bias=False,
                              pad_mode='pad')
        self.bn = nn.BatchNorm2d(num_out)
        self.use_act = use_act
        self.act = Activation(act_type) if use_act else None

    def construct(self, x):
        out = self.conv(x)
        out = self.bn(out)
        if self.use_act:
            out = self.act(out)
        return out

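# Note: the convolution omits its bias (has_bias=False) because the
# BatchNorm2d that follows supplies an equivalent learned shift (beta).
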
class GhostModule(nn.Cell):
    """
    GhostModule wrapper definition.

    Args:
        num_in (int): Input channel.
        num_out (int): Output channel.
        kernel_size (int): Kernel size of the primary convolution.
        stride (int): Stride size.
        padding (int): Padding number.
        ratio (int): Reduction ratio.
        dw_size (int): Kernel size of the cheap (depthwise) operation.
        use_act (bool): Whether to apply the activation.
        act_type (string): Activation type.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> GhostModule(3, 3)
    """

    def __init__(self, num_in, num_out, kernel_size=1, stride=1, padding=0, ratio=2, dw_size=3,
                 use_act=True, act_type='relu'):
        super(GhostModule, self).__init__()
        init_channels = math.ceil(num_out / ratio)
        new_channels = init_channels * (ratio - 1)

        self.primary_conv = ConvUnit(num_in, init_channels, kernel_size=kernel_size, stride=stride,
                                     padding=padding, num_groups=1, use_act=use_act, act_type='relu')
        self.cheap_operation = ConvUnit(init_channels, new_channels, kernel_size=dw_size, stride=1,
                                        padding=dw_size // 2, num_groups=init_channels,
                                        use_act=use_act, act_type='relu')
        self.concat = P.Concat(axis=1)

    def construct(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        return self.concat((x1, x2))

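# Note: a Ghost module computes num_out / ratio "intrinsic" channels with the
# primary convolution and generates the remaining channels from them with a
# cheap depthwise convolution, then concatenates the two along the channel
# axis. As written, both ConvUnits hard-code act_type='relu' even though an
# act_type argument is accepted.
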
class GhostBottleneck(nn.Cell):
    """
    GhostBottleneck wrapper definition.

    Args:
        num_in (int): Input channel.
        num_mid (int): Middle channel.
        num_out (int): Output channel.
        kernel_size (int): Kernel size of the depthwise convolution.
        stride (int): Stride size.
        act_type (str): Activation type.
        use_se (bool): Whether to use the SE wrapper.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> GhostBottleneck(16, 48, 24, 3)
    """

    def __init__(self, num_in, num_mid, num_out, kernel_size, stride=1, act_type='relu', use_se=False):
        super(GhostBottleneck, self).__init__()
        self.ghost1 = GhostModule(num_in, num_mid, kernel_size=1,
                                  stride=1, padding=0, act_type=act_type)

        self.use_dw = stride > 1
        self.dw = None
        if self.use_dw:
            self.dw = ConvUnit(num_mid, num_mid, kernel_size=kernel_size, stride=stride,
                               padding=self._get_pad(kernel_size), act_type=act_type,
                               num_groups=num_mid, use_act=False)

        self.use_se = use_se
        if use_se:
            self.se = SE(num_mid)

        self.ghost2 = GhostModule(num_mid, num_out, kernel_size=1, stride=1,
                                  padding=0, act_type=act_type, use_act=False)

        self.down_sample = False
        if num_in != num_out or stride != 1:
            self.down_sample = True
        self.shortcut = None
        if self.down_sample:
            self.shortcut = nn.SequentialCell([
                ConvUnit(num_in, num_in, kernel_size=kernel_size, stride=stride,
                         padding=self._get_pad(kernel_size), num_groups=num_in, use_act=False),
                ConvUnit(num_in, num_out, kernel_size=1, stride=1,
                         padding=0, num_groups=1, use_act=False),
            ])
        self.add = P.TensorAdd()

    def construct(self, x):
        """construct of GhostBottleneck"""
        shortcut = x
        out = self.ghost1(x)
        if self.use_dw:
            out = self.dw(out)
        if self.use_se:
            out = self.se(out)
        out = self.ghost2(out)
        if self.down_sample:
            shortcut = self.shortcut(shortcut)
        out = self.add(shortcut, out)
        return out

    def _get_pad(self, kernel_size):
        """set the padding number"""
        pad = 0
        if kernel_size == 1:
            pad = 0
        elif kernel_size == 3:
            pad = 1
        elif kernel_size == 5:
            pad = 2
        elif kernel_size == 7:
            pad = 3
        else:
            raise NotImplementedError
        return pad

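# Note: the bottleneck path is ghost1 (expand) -> optional stride-2 depthwise
# conv -> optional SE -> ghost2 (project, no activation), joined to the input
# by a residual add; when the shape changes, the shortcut is a depthwise conv
# followed by a pointwise conv. _get_pad returns (kernel_size - 1) // 2, i.e.
# 'same' padding for the odd kernel sizes used here.
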
class GhostNet(nn.Cell):
    """
    GhostNet architecture.

    Args:
        model_cfgs (dict): Model configuration (stage layout and head widths).
        num_classes (int): Number of output classes.
        multiplier (float): Width multiplier applied to the channel counts. Default is 1.
        final_drop (float): Dropout parameter for the classifier head; 0 disables dropout.
        round_nearest (int): Round channel counts to the nearest multiple of this value.
            Default is 8 (accepted but not used by this implementation).

    Returns:
        Tensor, output tensor.

    Examples:
        >>> GhostNet(model_cfgs, num_classes=1000)
    """

    def __init__(self, model_cfgs, num_classes=1000, multiplier=1., final_drop=0., round_nearest=8):
        super(GhostNet, self).__init__()
        self.cfgs = model_cfgs['cfg']
        self.inplanes = 16
        first_conv_in_channel = 3
        first_conv_out_channel = _make_divisible(multiplier * self.inplanes)

        self.conv_stem = nn.Conv2d(in_channels=first_conv_in_channel,
                                   out_channels=first_conv_out_channel,
                                   kernel_size=3, padding=1, stride=2,
                                   has_bias=False, pad_mode='pad')
        self.bn1 = nn.BatchNorm2d(first_conv_out_channel)
        self.act1 = Activation('relu')

        self.blocks = []
        for layer_cfg in self.cfgs:
            self.blocks.append(self._make_layer(kernel_size=layer_cfg[0],
                                                exp_ch=_make_divisible(multiplier * layer_cfg[1]),
                                                out_channel=_make_divisible(multiplier * layer_cfg[2]),
                                                use_se=layer_cfg[3],
                                                act_func=layer_cfg[4],
                                                stride=layer_cfg[5]))
        output_channel = _make_divisible(multiplier * model_cfgs["cls_ch_squeeze"])
        self.blocks.append(ConvUnit(_make_divisible(multiplier * self.cfgs[-1][2]), output_channel,
                                    kernel_size=1, stride=1, padding=0, num_groups=1, use_act=True))
        self.blocks = nn.SequentialCell(self.blocks)

        self.global_pool = GlobalAvgPooling(keep_dims=True)
        self.conv_head = nn.Conv2d(in_channels=output_channel,
                                   out_channels=model_cfgs['cls_ch_expand'],
                                   kernel_size=1, padding=0, stride=1,
                                   has_bias=True, pad_mode='pad')
        self.act2 = Activation('relu')
        self.squeeze = P.Flatten()
        self.final_drop = final_drop
        if self.final_drop > 0:
            self.dropout = nn.Dropout(self.final_drop)

        self.classifier = nn.Dense(model_cfgs['cls_ch_expand'], num_classes, has_bias=True)

        self._initialize_weights()

    def construct(self, x):
        """construct of GhostNet"""
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.blocks(x)
        x = self.global_pool(x)
        x = self.conv_head(x)
        x = self.act2(x)
        x = self.squeeze(x)
        if self.final_drop > 0:
            x = self.dropout(x)
        x = self.classifier(x)
        return x

    def _make_layer(self, kernel_size, exp_ch, out_channel, use_se, act_func, stride=1):
        mid_planes = exp_ch
        out_planes = out_channel
        layer = GhostBottleneck(self.inplanes, mid_planes, out_planes,
                                kernel_size, stride=stride, act_type=act_func, use_se=use_se)
        self.inplanes = out_planes
        return layer

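    # Note: _make_layer threads self.inplanes through successive calls, so each
    # bottleneck receives the previous block's output width as its input width.
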
    def _initialize_weights(self):
        """
        Initialize weights.

        Returns:
            None.

        Examples:
            >>> _initialize_weights()
        """
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.set_parameter_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
                                                                    m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_parameter_data(
                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
            elif isinstance(m, nn.BatchNorm2d):
                m.gamma.set_parameter_data(
                    Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
                m.beta.set_parameter_data(
                    Tensor(np.zeros(m.beta.data.shape, dtype="float32")))
            elif isinstance(m, nn.Dense):
                m.weight.set_parameter_data(Tensor(np.random.normal(
                    0, 0.01, m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_parameter_data(
                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))

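# Note: the convolution initializer above is He (Kaiming) normal with fan-out,
# std = sqrt(2 / (k_h * k_w * out_channels)); BatchNorm starts at gamma=1 and
# beta=0, and Dense layers use normal(0, 0.01) weights with zero bias.
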
def ghostnet(model_name, **kwargs):
    """
    Constructs a GhostNet model.
    """
    model_cfgs = {
        "1x": {
            "cfg": [
                # kernel_size, exp_channels, out_channels, use_se, act_func, stride
                # stage1
                [3, 16, 16, False, 'relu', 1],
                # stage2
                [3, 48, 24, False, 'relu', 2],
                [3, 72, 24, False, 'relu', 1],
                # stage3
                [5, 72, 40, True, 'relu', 2],
                [5, 120, 40, True, 'relu', 1],
                # stage4
                [3, 240, 80, False, 'relu', 2],
                [3, 200, 80, False, 'relu', 1],
                [3, 184, 80, False, 'relu', 1],
                [3, 184, 80, False, 'relu', 1],
                [3, 480, 112, True, 'relu', 1],
                [3, 672, 112, True, 'relu', 1],
                # stage5
                [5, 672, 160, True, 'relu', 2],
                [5, 960, 160, False, 'relu', 1],
                [5, 960, 160, True, 'relu', 1],
                [5, 960, 160, False, 'relu', 1],
                [5, 960, 160, True, 'relu', 1]],
            "cls_ch_squeeze": 960,
            "cls_ch_expand": 1280,
        },
        "nose_1x": {  # same layout as "1x" but with every SE block disabled
            "cfg": [
                # kernel_size, exp_channels, out_channels, use_se, act_func, stride
                # stage1
                [3, 16, 16, False, 'relu', 1],
                # stage2
                [3, 48, 24, False, 'relu', 2],
                [3, 72, 24, False, 'relu', 1],
                # stage3
                [5, 72, 40, False, 'relu', 2],
                [5, 120, 40, False, 'relu', 1],
                # stage4
                [3, 240, 80, False, 'relu', 2],
                [3, 200, 80, False, 'relu', 1],
                [3, 184, 80, False, 'relu', 1],
                [3, 184, 80, False, 'relu', 1],
                [3, 480, 112, False, 'relu', 1],
                [3, 672, 112, False, 'relu', 1],
                # stage5
                [5, 672, 160, False, 'relu', 2],
                [5, 960, 160, False, 'relu', 1],
                [5, 960, 160, False, 'relu', 1],
                [5, 960, 160, False, 'relu', 1],
                [5, 960, 160, False, 'relu', 1]],
            "cls_ch_squeeze": 960,
            "cls_ch_expand": 1280,
        }
    }
    return GhostNet(model_cfgs[model_name], **kwargs)

ghostnet_1x = partial(ghostnet, model_name="1x", final_drop=0.8)
ghostnet_nose_1x = partial(ghostnet, model_name="nose_1x", final_drop=0.8)
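
# A minimal usage sketch (illustrative assumption, not part of the original
# model zoo file): build the 1x model and push a dummy ImageNet-sized batch
# through it. Assumes a working MindSpore install; the 224x224 input size is
# the conventional choice, not mandated by this file.
if __name__ == '__main__':
    net = ghostnet_1x(num_classes=1000)
    dummy_input = Tensor(np.random.randn(1, 3, 224, 224).astype(np.float32))
    logits = net(dummy_input)
    print(logits.shape)  # expected: (1, 1000)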