
tinynet.py 30 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tinynet model definition"""
import math
import re
from copy import deepcopy

import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P
from mindspore.common.initializer import Normal, Zero, One, Uniform
from mindspore import ms_function
from mindspore import Tensor

# Imagenet constant values
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# model structure configurations for TinyNets, values are
# (resolution multiplier, channel multiplier, depth multiplier)
# codes are inspired and partially adapted from
# https://github.com/rwightman/gen-efficientnet-pytorch
TINYNET_CFG = {"a": (0.86, 1.0, 1.2),
               "b": (0.84, 0.75, 1.1),
               "c": (0.825, 0.54, 0.85),
               "d": (0.68, 0.54, 0.695),
               "e": (0.475, 0.51, 0.60)}

relu = P.ReLU()
sigmoid = P.Sigmoid()

def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv_stem', 'classifier': 'classifier',
        **kwargs
    }

default_cfgs = {
    'efficientnet_b0': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth'),
    'efficientnet_b1': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
        input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
    'efficientnet_b2': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth',
        input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
    'efficientnet_b3': _cfg(
        url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
    'efficientnet_b4': _cfg(
        url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
}

_DEBUG = False

# Default args for PyTorch BN impl
_BN_MOMENTUM_PT_DEFAULT = 0.1
_BN_EPS_PT_DEFAULT = 1e-5
_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT)

# Defaults used for Google/Tensorflow training of mobile networks /w
# RMSprop as per papers and TF reference implementations. PT momentum
# equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies btw .99 and .9997 depending on source
# .99 in official TF TPU impl
# .9997 (/w .999 in search space) for paper
_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
_BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT)

def _initialize_weight_goog(shape=None, layer_type='conv', bias=False):
    """Google style weight initialization"""
    if layer_type not in ('conv', 'bn', 'fc'):
        raise ValueError(
            'The layer type is not known, the supported are conv, bn and fc')
    if bias:
        return Zero()
    if layer_type == 'conv':
        assert isinstance(shape, (tuple, list)) and len(
            shape) == 3, 'The shape must be 3 scalars, and are in_chs, ks, out_chs respectively'
        n = shape[1] * shape[1] * shape[2]
        return Normal(math.sqrt(2.0 / n))
    if layer_type == 'bn':
        return One()
    assert isinstance(shape, (tuple, list)) and len(
        shape) == 2, 'The shape must be 2 scalars, and are in_chs, out_chs respectively'
    n = shape[1]
    init_range = 1.0 / math.sqrt(n)
    return Uniform(init_range)
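
# Worked example of the fan-out init above (added for illustration, not part of the
# original file): for a 3x3 conv with 32 output channels the shape tuple is
# (in_chs, 3, 32), so n = 3 * 3 * 32 = 288 and the weights are drawn from
# Normal(std=sqrt(2.0 / 288)) ~ Normal(std=0.083), i.e. He/MSRA-style fan-out init.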

def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0,
          pad_mode='same', bias=False):
    """convolution wrapper"""
    weight_init_value = _initialize_weight_goog(
        shape=(in_channels, kernel_size, out_channels))
    bias_init_value = _initialize_weight_goog(bias=True) if bias else None
    if bias:
        return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                         padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
                         has_bias=bias, bias_init=bias_init_value)
    return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                     padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
                     has_bias=bias)


def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', bias=False):
    """1x1 convolution wrapper"""
    weight_init_value = _initialize_weight_goog(
        shape=(in_channels, 1, out_channels))
    bias_init_value = _initialize_weight_goog(bias=True) if bias else None
    if bias:
        return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
                         padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
                         has_bias=bias, bias_init=bias_init_value)
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
                     padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
                     has_bias=bias)


def _conv_group(in_channels, out_channels, group, kernel_size=3, stride=1, padding=0,
                pad_mode='same', bias=False):
    """group convolution wrapper"""
    weight_init_value = _initialize_weight_goog(
        shape=(in_channels, kernel_size, out_channels))
    bias_init_value = _initialize_weight_goog(bias=True) if bias else None
    if bias:
        return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                         padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
                         group=group, has_bias=bias, bias_init=bias_init_value)
    return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                     padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
                     group=group, has_bias=bias)


def _fused_bn(channels, momentum=0.1, eps=1e-4, gamma_init=1, beta_init=0):
    return nn.BatchNorm2d(channels, eps=eps, momentum=1 - momentum, gamma_init=gamma_init,
                          beta_init=beta_init)
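
# Note on the momentum conversion above (added for clarity): MindSpore's BatchNorm2d
# momentum is the fraction of the running statistic kept each step, whereas the
# PyTorch-style values stored in _BN_ARGS_PT/_BN_ARGS_TF give the fraction contributed
# by the current batch. Passing momentum=0.1 here therefore becomes momentum=0.9 in
# nn.BatchNorm2d, so both conventions update the running statistics identically.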

def _dense(in_channels, output_channels, bias=True, activation=None):
    weight_init_value = _initialize_weight_goog(shape=(in_channels, output_channels),
                                                layer_type='fc')
    bias_init_value = _initialize_weight_goog(bias=True) if bias else None
    if bias:
        return nn.Dense(in_channels, output_channels, weight_init=weight_init_value,
                        bias_init=bias_init_value, has_bias=bias, activation=activation)
    return nn.Dense(in_channels, output_channels, weight_init=weight_init_value,
                    has_bias=bias, activation=activation)

def _resolve_bn_args(kwargs):
    bn_args = _BN_ARGS_TF.copy() if kwargs.pop(
        'bn_tf', False) else _BN_ARGS_PT.copy()
    bn_momentum = kwargs.pop('bn_momentum', None)
    if bn_momentum is not None:
        bn_args['momentum'] = bn_momentum
    bn_eps = kwargs.pop('bn_eps', None)
    if bn_eps is not None:
        bn_args['eps'] = bn_eps
    return bn_args


def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
    """Round number of filters based on depth multiplier."""
    if not multiplier:
        return channels
    channels *= multiplier
    channel_min = channel_min or divisor
    new_channels = max(
        int(channels + divisor / 2) // divisor * divisor,
        channel_min)
    # Make sure that round down does not go down by more than 10%.
    if new_channels < 0.9 * channels:
        new_channels += divisor
    return new_channels
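
# Worked example (added for illustration): _round_channels(40, multiplier=0.54, divisor=8)
# scales 40 -> 21.6 and rounds to the nearest multiple of 8 via int(21.6 + 4) // 8 * 8 = 24;
# since 24 is not more than 10% below 21.6, no correction is applied, so a 40-channel
# EfficientNet-B0 stage becomes a 24-channel stage under TinyNet-c's 0.54 width multiplier.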

def _parse_ksize(ss):
    if ss.isdigit():
        return int(ss)
    return [int(k) for k in ss.split('.')]

def _decode_block_str(block_str, depth_multiplier=1.0):
    """Decode a block definition string.

    Gets a block args dict through a string notation of arguments,
    e.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip

    All args can exist in any order with the exception of the leading string which
    is assumed to indicate the block type.

    leading string - block type (
        ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
    r - number of repeat blocks,
    k - kernel size,
    s - strides (1-9),
    e - expansion ratio,
    c - output channels,
    se - squeeze/excitation ratio
    n - activation fn ('re', 'r6', 'hs', or 'sw')

    Args:
        block_str: a string representation of block arguments.
    Returns:
        A block args dict and the number of repeats for the block.
    Raises:
        ValueError: if the string def is not properly specified (TODO)
    """
    assert isinstance(block_str, str)
    ops = block_str.split('_')
    block_type = ops[0]  # take the block type off the front
    ops = ops[1:]
    options = {}
    noskip = False
    for op in ops:
        if op == 'noskip':
            noskip = True
        elif op.startswith('n'):
            # activation fn (per-block overrides are not supported in this port)
            key = op[0]
            v = op[1:]
            if v in ('re', 'r6', 'hs', 'sw'):
                print('not support')
            else:
                continue
            options[key] = v
        else:
            # all numeric options
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value
    act_fn = options['n'] if 'n' in options else None
    exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
    pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
    fake_in_chs = int(options['fc']) if 'fc' in options else 0
    num_repeat = int(options['r'])
    # each type of block has different valid arguments, fill accordingly
    if block_type == 'ir':
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=_parse_ksize(options['k']),
            exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            exp_ratio=float(options['e']),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_fn=act_fn,
            noskip=noskip,
        )
    elif block_type in ('ds', 'dsa'):
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=_parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_fn=act_fn,
            pw_act=block_type == 'dsa',
            noskip=block_type == 'dsa' or noskip,
        )
    elif block_type == 'er':
        block_args = dict(
            block_type=block_type,
            exp_kernel_size=_parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            exp_ratio=float(options['e']),
            fake_in_chs=fake_in_chs,
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_fn=act_fn,
            noskip=noskip,
        )
    elif block_type == 'cn':
        block_args = dict(
            block_type=block_type,
            kernel_size=int(options['k']),
            out_chs=int(options['c']),
            stride=int(options['s']),
            act_fn=act_fn,
        )
    else:
        assert False, 'Unknown block type (%s)' % block_type
    return block_args, num_repeat
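
# Decoding example (added for illustration): _decode_block_str('ir_r2_k3_s2_e6_c24_se0.25')
# returns num_repeat=2 and a block_args dict equivalent to
#   dict(block_type='ir', dw_kernel_size=3, exp_kernel_size=1, pw_kernel_size=1,
#        out_chs=24, exp_ratio=6.0, se_ratio=0.25, stride=2, act_fn=None, noskip=False),
# i.e. a stage of two inverted-residual blocks with a 3x3 depthwise kernel, stride 2,
# expansion ratio 6, 24 output channels and a 0.25 squeeze-excite ratio.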

def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
    """Per-stage depth scaling.

    Scales the block repeats in each stage. This depth scaling impl maintains
    compatibility with the EfficientNet scaling method, while allowing sensible
    scaling for other models that may have multiple block arg definitions in each stage.
    """
    # We scale the total repeat count for each stage, there may be multiple
    # block arg defs per stage so we need to sum.
    num_repeat = sum(repeats)
    if depth_trunc == 'round':
        # Truncating to int by rounding allows stages with few repeats to remain
        # proportionally smaller for longer. This is a good choice when stage definitions
        # include single repeat stages that we'd prefer to keep that way as long as possible
        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
    else:
        # The default for EfficientNet truncates repeats to int via 'ceil'.
        # Any multiplier > 1.0 will result in an increased depth for every stage.
        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
    # Proportionally distribute repeat count scaling to each block definition in the stage.
    # Allocation is done in reverse as it results in the first block being less likely to be scaled.
    # The first block makes less sense to repeat in most of the arch definitions.
    repeats_scaled = []
    for r in repeats[::-1]:
        rs = max(1, round((r / num_repeat * num_repeat_scaled)))
        repeats_scaled.append(rs)
        num_repeat -= r
        num_repeat_scaled -= rs
    repeats_scaled = repeats_scaled[::-1]
    # Apply the calculated scaling to each block arg in the stage
    sa_scaled = []
    for ba, rep in zip(stack_args, repeats_scaled):
        sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
    return sa_scaled
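
# Scaling example (added for illustration): with depth_multiplier=0.85 and
# depth_trunc='round' (the settings TinyNet-c ends up using via _gen_efficientnet), a
# stage defined with repeats=[4] is scaled to max(1, round(4 * 0.85)) = 3 repeats, so
# _scale_stage_depth returns three copies of that stage's single block definition.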

def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
    """further decode the architecture definition into model-ready format"""
    arch_args = []
    for _, block_strings in enumerate(arch_def):
        assert isinstance(block_strings, list)
        stack_args = []
        repeats = []
        for block_str in block_strings:
            assert isinstance(block_str, str)
            ba, rep = _decode_block_str(block_str)
            stack_args.append(ba)
            repeats.append(rep)
        arch_args.append(_scale_stage_depth(
            stack_args, repeats, depth_multiplier, depth_trunc))
    return arch_args

class Swish(nn.Cell):
    """swish activation function"""

    def __init__(self):
        super(Swish, self).__init__()
        self.sigmoid = P.Sigmoid()

    def construct(self, x):
        return x * self.sigmoid(x)


@ms_function
def swish(x):
    return x * nn.Sigmoid()(x)

class BlockBuilder(nn.Cell):
    """build efficient-net convolution blocks"""

    def __init__(self, builder_in_channels, builder_block_args, channel_multiplier=1.0,
                 channel_divisor=8, channel_min=None, pad_type='', act_fn=None,
                 se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None,
                 drop_connect_rate=0., verbose=False):
        super(BlockBuilder, self).__init__()
        self.channel_multiplier = channel_multiplier
        self.channel_divisor = channel_divisor
        self.channel_min = channel_min
        self.pad_type = pad_type
        self.act_fn = Swish()
        self.se_gate_fn = se_gate_fn
        self.se_reduce_mid = se_reduce_mid
        self.bn_args = bn_args
        self.drop_connect_rate = drop_connect_rate
        self.verbose = verbose
        # updated during build
        self.in_chs = None
        self.block_idx = 0
        self.block_count = 0
        self.layer = self._make_layer(builder_in_channels, builder_block_args)

    def _round_channels(self, chs):
        return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)

    def _make_block(self, ba):
        """make the current block based on the block argument"""
        bt = ba.pop('block_type')
        ba['in_chs'] = self.in_chs
        ba['out_chs'] = self._round_channels(ba['out_chs'])
        if 'fake_in_chs' in ba and ba['fake_in_chs']:
            # this is a hack to work around mismatch in origin impl input filters
            ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
        ba['bn_args'] = self.bn_args
        ba['pad_type'] = self.pad_type
        # block act fn overrides the model default
        ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
        assert ba['act_fn'] is not None
        if bt == 'ir':
            ba['drop_connect_rate'] = self.drop_connect_rate * \
                self.block_idx / self.block_count
            ba['se_gate_fn'] = self.se_gate_fn
            ba['se_reduce_mid'] = self.se_reduce_mid
            block = InvertedResidual(**ba)
        elif bt in ('ds', 'dsa'):
            ba['drop_connect_rate'] = self.drop_connect_rate * \
                self.block_idx / self.block_count
            block = DepthwiseSeparableConv(**ba)
        else:
            assert False, 'Unknown block type (%s) while building model.' % bt
        self.in_chs = ba['out_chs']
        return block

    def _make_stack(self, stack_args):
        """make a stack of blocks"""
        blocks = []
        # each stack (stage) contains a list of block arguments
        for i, ba in enumerate(stack_args):
            if i >= 1:
                # only the first block in any stack can have a stride > 1
                ba['stride'] = 1
            block = self._make_block(ba)
            blocks.append(block)
            self.block_idx += 1  # incr global idx (across all stacks)
        return nn.SequentialCell(blocks)

    def _make_layer(self, in_chs, block_args):
        """Build the entire layer.

        Args:
            in_chs: Number of input-channels passed to first block
            block_args: A list of lists, outer list defines stages, inner
                list contains strings defining block configuration(s)
        Return:
            List of block stacks (each stack wrapped in nn.SequentialCell)
        """
        self.in_chs = in_chs
        self.block_count = sum([len(x) for x in block_args])
        self.block_idx = 0
        blocks = []
        # outer list of block_args defines the stacks ('stages' by some conventions)
        for _, stack in enumerate(block_args):
            assert isinstance(stack, list)
            stack = self._make_stack(stack)
            blocks.append(stack)
        return nn.SequentialCell(blocks)

    def construct(self, x):
        return self.layer(x)

class DropConnect(nn.Cell):
    """drop connect implementation"""

    def __init__(self, drop_connect_rate=0., seed0=0, seed1=0):
        super(DropConnect, self).__init__()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.keep_prob = 1 - drop_connect_rate
        self.dropout = P.Dropout(keep_prob=self.keep_prob)
        self.keep_prob_tensor = Tensor(self.keep_prob, dtype=mstype.float32)

    def construct(self, x):
        shape = self.shape(x)
        dtype = self.dtype(x)
        ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1)
        _, mask = self.dropout(ones_tensor)
        x = x * mask
        x = x / self.keep_prob_tensor
        return x


def drop_connect(inputs, training=False, drop_connect_rate=0.):
    if not training:
        return inputs
    return DropConnect(drop_connect_rate)(inputs)

class SqueezeExcite(nn.Cell):
    """squeeze-excite implementation"""

    def __init__(self, in_chs, reduce_chs=None, act_fn=relu, gate_fn=sigmoid):
        super(SqueezeExcite, self).__init__()
        self.act_fn = Swish()
        self.gate_fn = gate_fn
        reduce_chs = reduce_chs or in_chs
        self.conv_reduce = nn.Conv2d(in_channels=in_chs, out_channels=reduce_chs,
                                     kernel_size=1, has_bias=True, pad_mode='pad')
        self.conv_expand = nn.Conv2d(in_channels=reduce_chs, out_channels=in_chs,
                                     kernel_size=1, has_bias=True, pad_mode='pad')
        self.avg_global_pool = P.ReduceMean(keep_dims=True)

    def construct(self, x):
        x_se = self.avg_global_pool(x, (2, 3))
        x_se = self.conv_reduce(x_se)
        x_se = self.act_fn(x_se)
        x_se = self.conv_expand(x_se)
        x_se = self.gate_fn(x_se)
        x = x * x_se
        return x

class DepthwiseSeparableConv(nn.Cell):
    """depth-wise convolution -> (squeeze-excite) -> point-wise convolution"""

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, pad_type='', act_fn=relu, noskip=False,
                 pw_kernel_size=1, pw_act=False, se_ratio=0., se_gate_fn=sigmoid,
                 bn_args=None, drop_connect_rate=0.):
        super(DepthwiseSeparableConv, self).__init__()
        assert stride in [1, 2], 'stride must be 1 or 2'
        self.has_se = se_ratio is not None and se_ratio > 0.
        self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
        self.has_pw_act = pw_act
        self.act_fn = Swish()
        self.drop_connect_rate = drop_connect_rate
        self.conv_dw = nn.Conv2d(in_chs, in_chs, dw_kernel_size, stride, pad_mode="pad",
                                 padding=int(dw_kernel_size / 2), has_bias=False, group=in_chs,
                                 weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, in_chs]))
        self.bn1 = _fused_bn(in_chs, **bn_args)
        if self.has_se:
            self.se = SqueezeExcite(in_chs, reduce_chs=max(1, int(in_chs * se_ratio)),
                                    act_fn=act_fn, gate_fn=se_gate_fn)
        self.conv_pw = _conv1x1(in_chs, out_chs)
        self.bn2 = _fused_bn(out_chs, **bn_args)

    def construct(self, x):
        """forward the depthwise separable conv"""
        identity = x
        x = self.conv_dw(x)
        x = self.bn1(x)
        x = self.act_fn(x)
        if self.has_se:
            x = self.se(x)
        x = self.conv_pw(x)
        x = self.bn2(x)
        if self.has_pw_act:
            x = self.act_fn(x)
        if self.has_residual:
            if self.drop_connect_rate > 0.:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x = x + identity
        return x

class InvertedResidual(nn.Cell):
    """inverted-residual block implementation"""

    def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1,
                 pad_type='', act_fn=relu, pw_kernel_size=1,
                 noskip=False, exp_ratio=1., exp_kernel_size=1, se_ratio=0.,
                 se_reduce_mid=False, se_gate_fn=sigmoid, shuffle_type=None,
                 bn_args=None, drop_connect_rate=0.):
        super(InvertedResidual, self).__init__()
        mid_chs = int(in_chs * exp_ratio)
        self.has_se = se_ratio is not None and se_ratio > 0.
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.act_fn = Swish()
        self.drop_connect_rate = drop_connect_rate
        self.conv_pw = _conv(in_chs, mid_chs, exp_kernel_size)
        self.bn1 = _fused_bn(mid_chs, **bn_args)
        self.shuffle_type = shuffle_type
        if self.shuffle_type is not None and isinstance(exp_kernel_size, list):
            self.shuffle = None
        self.conv_dw = nn.Conv2d(mid_chs, mid_chs, dw_kernel_size, stride, pad_mode="pad",
                                 padding=int(dw_kernel_size / 2), has_bias=False, group=mid_chs,
                                 weight_init=_initialize_weight_goog(shape=[1, dw_kernel_size, mid_chs]))
        self.bn2 = _fused_bn(mid_chs, **bn_args)
        if self.has_se:
            se_base_chs = mid_chs if se_reduce_mid else in_chs
            self.se = SqueezeExcite(
                mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)),
                act_fn=act_fn, gate_fn=se_gate_fn
            )
        self.conv_pwl = _conv(mid_chs, out_chs, pw_kernel_size)
        self.bn3 = _fused_bn(out_chs, **bn_args)

    def construct(self, x):
        """forward the inverted-residual block"""
        identity = x
        x = self.conv_pw(x)
        x = self.bn1(x)
        x = self.act_fn(x)
        x = self.conv_dw(x)
        x = self.bn2(x)
        x = self.act_fn(x)
        if self.has_se:
            x = self.se(x)
        x = self.conv_pwl(x)
        x = self.bn3(x)
        if self.has_residual:
            if self.drop_connect_rate > 0:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x = x + identity
        return x

class GenEfficientNet(nn.Cell):
    """Generate EfficientNet architecture"""

    def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
                 channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                 pad_type='', act_fn=relu, drop_rate=0., drop_connect_rate=0.,
                 se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None,
                 global_pool='avg', head_conv='default', weight_init='goog'):
        super(GenEfficientNet, self).__init__()
        bn_args = _BN_ARGS_PT if bn_args is None else bn_args
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.num_features = num_features
        self.conv_stem = _conv(in_chans, stem_size, 3,
                               stride=2, padding=1, pad_mode='pad')
        self.bn1 = _fused_bn(stem_size, **bn_args)
        self.act_fn = Swish()
        in_chans = stem_size
        self.blocks = BlockBuilder(in_chans, block_args, channel_multiplier,
                                   channel_divisor, channel_min,
                                   pad_type, act_fn, se_gate_fn, se_reduce_mid,
                                   bn_args, drop_connect_rate, verbose=_DEBUG)
        in_chs = self.blocks.in_chs
        if not head_conv or head_conv == 'none':
            self.efficient_head = False
            self.conv_head = None
            assert in_chs == self.num_features
        else:
            self.efficient_head = head_conv == 'efficient'
            self.conv_head = _conv1x1(in_chs, self.num_features)
            self.bn2 = None if self.efficient_head else _fused_bn(
                self.num_features, **bn_args)
        self.global_pool = P.ReduceMean(keep_dims=True)
        self.classifier = _dense(self.num_features, self.num_classes)
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.drop_out = nn.Dropout(keep_prob=1 - self.drop_rate)

    def construct(self, x):
        """efficient net entry point"""
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act_fn(x)
        x = self.blocks(x)
        if self.efficient_head:
            x = self.global_pool(x, (2, 3))
            x = self.conv_head(x)
            x = self.act_fn(x)
            x = self.reshape(x, (self.shape(x)[0], -1))
        else:
            if self.conv_head is not None:
                x = self.conv_head(x)
                x = self.bn2(x)
            x = self.act_fn(x)
            x = self.global_pool(x, (2, 3))
            x = self.reshape(x, (self.shape(x)[0], -1))
        if self.training and self.drop_rate > 0.:
            x = self.drop_out(x)
        return self.classifier(x)

def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
    """Creates an EfficientNet model.

    Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
    Paper: https://arxiv.org/abs/1905.11946

    EfficientNet params
    name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
    'efficientnet-b0': (1.0, 1.0, 224, 0.2),
    'efficientnet-b1': (1.0, 1.1, 240, 0.2),
    'efficientnet-b2': (1.1, 1.2, 260, 0.3),
    'efficientnet-b3': (1.2, 1.4, 300, 0.3),
    'efficientnet-b4': (1.4, 1.8, 380, 0.4),
    'efficientnet-b5': (1.6, 2.2, 456, 0.4),
    'efficientnet-b6': (1.8, 2.6, 528, 0.5),
    'efficientnet-b7': (2.0, 3.1, 600, 0.5),

    Args:
        channel_multiplier (float): multiplier to number of channels per layer
        depth_multiplier (float): multiplier to number of repeats per stage
    """
    arch_def = [
        ['ds_r1_k3_s1_e1_c16_se0.25'],
        ['ir_r2_k3_s2_e6_c24_se0.25'],
        ['ir_r2_k5_s2_e6_c40_se0.25'],
        ['ir_r3_k3_s2_e6_c80_se0.25'],
        ['ir_r3_k5_s1_e6_c112_se0.25'],
        ['ir_r4_k5_s2_e6_c192_se0.25'],
        ['ir_r1_k3_s1_e6_c320_se0.25'],
    ]
    num_features = max(1280, _round_channels(
        1280, channel_multiplier, 8, None))
    model = GenEfficientNet(
        _decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
        num_classes=num_classes,
        stem_size=32,
        channel_multiplier=channel_multiplier,
        num_features=num_features,
        bn_args=_resolve_bn_args(kwargs),
        act_fn=Swish,
        **kwargs)
    return model

def tinynet(sub_model="c", num_classes=1000, in_chans=3, **kwargs):
    """TinyNet models"""
    # choose a sub model
    r, w, d = TINYNET_CFG[sub_model]
    default_cfg = default_cfgs['efficientnet_b0']
    assert default_cfg['input_size'] == (3, 224, 224), \
        "All TinyNet models are evolved from EfficientNet-B0, which has an input dimension of 3*224*224"
    channel, height, width = default_cfg['input_size']
    height = int(r * height)
    width = int(r * width)
    default_cfg['input_size'] = (channel, height, width)
    print("Data processing configuration for current model + dataset:")
    print("input_size:", default_cfg['input_size'])
    print("channel multiplier:%s, depth multiplier:%s, resolution multiplier:%s" % (w, d, r))
    model = _gen_efficientnet(
        channel_multiplier=w, depth_multiplier=d,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    return model
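
# A minimal smoke-test sketch (added for illustration, not part of the original file;
# assumes MindSpore and NumPy are installed and the default execution context is
# acceptable). The 184x184 input matches the rescaled default_cfg that tinynet("c")
# reports (int(0.825 * 224) = 184); the expected output shape is (1, 1000).
if __name__ == "__main__":
    import numpy as np
    net = tinynet(sub_model="c", num_classes=1000)
    net.set_train(False)  # inference mode: disables dropout / drop-connect paths
    dummy = Tensor(np.zeros((1, 3, 184, 184)), mstype.float32)
    logits = net(dummy)
    print(logits.shape)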