tinynet.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tinynet model definition"""
import math
import re
from copy import deepcopy

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import Normal, Zero, One, initializer, Uniform
from mindspore import context, ms_function
from mindspore.common.parameter import Parameter

# ImageNet constant values
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# model structure configurations for TinyNets, values are
# (resolution multiplier, channel multiplier, depth multiplier)
# only tinynet-c is available for now, we will release other tinynet
# models soon
# code is inspired by and partially adapted from
# https://github.com/rwightman/gen-efficientnet-pytorch
TINYNET_CFG = {"c": (0.825, 0.54, 0.85)}

relu = P.ReLU()
sigmoid = P.Sigmoid()

def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv_stem', 'classifier': 'classifier',
        **kwargs
    }


default_cfgs = {
    'efficientnet_b0': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth'),
    'efficientnet_b1': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
        input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
    'efficientnet_b2': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth',
        input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
    'efficientnet_b3': _cfg(
        url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
    'efficientnet_b4': _cfg(
        url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
}

_DEBUG = False

# Default args for PyTorch BN impl
_BN_MOMENTUM_PT_DEFAULT = 0.1
_BN_EPS_PT_DEFAULT = 1e-5
_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT)

# Defaults used for Google/Tensorflow training of mobile networks /w
# RMSprop as per papers and TF reference implementations. PT momentum
# equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies btw .99 and .9997 depending on source
# .99 in official TF TPU impl
# .9997 (/w .999 in search space) for paper
_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
_BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT)

def _initialize_weight_goog(shape=None, layer_type='conv', bias=False):
    """Google style weight initialization"""
    if layer_type not in ('conv', 'bn', 'fc'):
        raise ValueError(
            'The layer type is not known, the supported are conv, bn and fc')
    if bias:
        return Zero()
    if layer_type == 'conv':
        assert isinstance(shape, (tuple, list)) and len(
            shape) == 3, 'The shape must be 3 scalars, and are in_chs, ks, out_chs respectively'
        n = shape[1] * shape[1] * shape[2]
        return Normal(math.sqrt(2.0 / n))
    if layer_type == 'bn':
        return One()
    assert isinstance(shape, (tuple, list)) and len(
        shape) == 2, 'The shape must be 2 scalars, and are in_chs, out_chs respectively'
    n = shape[1]
    init_range = 1.0 / math.sqrt(n)
    return Uniform(init_range)

def _conv(in_channels, out_channels, kernel_size=3, stride=1, padding=0,
          pad_mode='same', bias=False):
    """convolution wrapper"""
    weight_init_value = _initialize_weight_goog(
        shape=(in_channels, kernel_size, out_channels))
    bias_init_value = _initialize_weight_goog(bias=True) if bias else None
    if bias:
        return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                         padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
                         has_bias=bias, bias_init=bias_init_value)
    return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                     padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
                     has_bias=bias)


def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', bias=False):
    """1x1 convolution wrapper"""
    weight_init_value = _initialize_weight_goog(
        shape=(in_channels, 1, out_channels))
    bias_init_value = _initialize_weight_goog(bias=True) if bias else None
    if bias:
        return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
                         padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
                         has_bias=bias, bias_init=bias_init_value)
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride,
                     padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
                     has_bias=bias)


def _conv_group(in_channels, out_channels, group, kernel_size=3, stride=1, padding=0,
                pad_mode='same', bias=False):
    """group convolution wrapper"""
    weight_init_value = _initialize_weight_goog(
        shape=(in_channels, kernel_size, out_channels))
    bias_init_value = _initialize_weight_goog(bias=True) if bias else None
    if bias:
        return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                         padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
                         group=group, has_bias=bias, bias_init=bias_init_value)
    return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                     padding=padding, pad_mode=pad_mode, weight_init=weight_init_value,
                     group=group, has_bias=bias)


def _fused_bn(channels, momentum=0.1, eps=1e-4, gamma_init=1, beta_init=0):
    return nn.BatchNorm2d(channels, eps=eps, momentum=1 - momentum, gamma_init=gamma_init,
                          beta_init=beta_init)


def _dense(in_channels, output_channels, bias=True, activation=None):
    weight_init_value = _initialize_weight_goog(shape=(in_channels, output_channels),
                                                layer_type='fc')
    bias_init_value = _initialize_weight_goog(bias=True) if bias else None
    if bias:
        return nn.Dense(in_channels, output_channels, weight_init=weight_init_value,
                        bias_init=bias_init_value, has_bias=bias, activation=activation)
    return nn.Dense(in_channels, output_channels, weight_init=weight_init_value,
                    has_bias=bias, activation=activation)

def _resolve_bn_args(kwargs):
    bn_args = _BN_ARGS_TF.copy() if kwargs.pop(
        'bn_tf', False) else _BN_ARGS_PT.copy()
    bn_momentum = kwargs.pop('bn_momentum', None)
    if bn_momentum is not None:
        bn_args['momentum'] = bn_momentum
    bn_eps = kwargs.pop('bn_eps', None)
    if bn_eps is not None:
        bn_args['eps'] = bn_eps
    return bn_args


def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
    """Round number of filters based on depth multiplier."""
    if not multiplier:
        return channels
    channels *= multiplier
    channel_min = channel_min or divisor
    new_channels = max(
        int(channels + divisor / 2) // divisor * divisor,
        channel_min)
    # Make sure that round down does not go down by more than 10%.
    if new_channels < 0.9 * channels:
        new_channels += divisor
    return new_channels

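# Illustrative example (added comment, not in the original file): with the tinynet-c
# channel multiplier of 0.54, a 32-channel layer becomes 32 * 0.54 = 17.28, which is
# rounded to the nearest multiple of 8, i.e. 16; since 16 >= 0.9 * 17.28 the divisor
# is not added back, so _round_channels(32, multiplier=0.54) == 16.
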
def _parse_ksize(ss):
    if ss.isdigit():
        return int(ss)
    return [int(k) for k in ss.split('.')]

def _decode_block_str(block_str, depth_multiplier=1.0):
    """ Decode block definition string

    Gets the block args (a dict) from a string notation of arguments.
    E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip

    All args can exist in any order with the exception of the leading string which
    is assumed to indicate the block type.

    leading string - block type (
        ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
    r - number of repeat blocks,
    k - kernel size,
    s - strides (1-9),
    e - expansion ratio,
    c - output channels,
    se - squeeze/excitation ratio
    n - activation fn ('re', 'r6', 'hs', or 'sw')

    Args:
        block_str: a string representation of block arguments.
    Returns:
        A dict of block args and the repeat count for the block.
    Raises:
        ValueError: if the string definition is not properly specified (TODO)
    """
    assert isinstance(block_str, str)
    ops = block_str.split('_')
    block_type = ops[0]  # take the block type off the front
    ops = ops[1:]
    options = {}
    noskip = False
    for op in ops:
        if op == 'noskip':
            noskip = True
        elif op.startswith('n'):
            # activation fn; custom activations are not supported yet,
            # so the model default is used instead
            key = op[0]
            v = op[1:]
            if v == 're':
                print('not support')
            elif v == 'r6':
                print('not support')
            elif v == 'hs':
                print('not support')
            elif v == 'sw':
                print('not support')
            else:
                continue
            options[key] = v
        else:
            # all numeric options
            splits = re.split(r'(\d.*)', op)
            if len(splits) >= 2:
                key, value = splits[:2]
                options[key] = value

    act_fn = options['n'] if 'n' in options else None
    exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
    pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
    fake_in_chs = int(options['fc']) if 'fc' in options else 0
    num_repeat = int(options['r'])

    # each type of block has different valid arguments, fill accordingly
    if block_type == 'ir':
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=_parse_ksize(options['k']),
            exp_kernel_size=exp_kernel_size,
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            exp_ratio=float(options['e']),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_fn=act_fn,
            noskip=noskip,
        )
    elif block_type in ('ds', 'dsa'):
        block_args = dict(
            block_type=block_type,
            dw_kernel_size=_parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_fn=act_fn,
            pw_act=block_type == 'dsa',
            noskip=block_type == 'dsa' or noskip,
        )
    elif block_type == 'er':
        block_args = dict(
            block_type=block_type,
            exp_kernel_size=_parse_ksize(options['k']),
            pw_kernel_size=pw_kernel_size,
            out_chs=int(options['c']),
            exp_ratio=float(options['e']),
            fake_in_chs=fake_in_chs,
            se_ratio=float(options['se']) if 'se' in options else None,
            stride=int(options['s']),
            act_fn=act_fn,
            noskip=noskip,
        )
    elif block_type == 'cn':
        block_args = dict(
            block_type=block_type,
            kernel_size=int(options['k']),
            out_chs=int(options['c']),
            stride=int(options['s']),
            act_fn=act_fn,
        )
    else:
        assert False, 'Unknown block type (%s)' % block_type
    return block_args, num_repeat

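# Illustrative example (added comment, not in the original file): decoding the
# EfficientNet-B0 stage string 'ir_r2_k3_s2_e6_c24_se0.25' returns num_repeat=2 and
# block_args equal to dict(block_type='ir', dw_kernel_size=3, exp_kernel_size=1,
# pw_kernel_size=1, out_chs=24, exp_ratio=6.0, se_ratio=0.25, stride=2,
# act_fn=None, noskip=False).
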
def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
    """ Per-stage depth scaling
    Scales the block repeats in each stage. This depth scaling impl maintains
    compatibility with the EfficientNet scaling method, while allowing sensible
    scaling for other models that may have multiple block arg definitions in each stage.
    """
    # We scale the total repeat count for each stage, there may be multiple
    # block arg defs per stage so we need to sum.
    num_repeat = sum(repeats)
    if depth_trunc == 'round':
        # Truncating to int by rounding allows stages with few repeats to remain
        # proportionally smaller for longer. This is a good choice when stage definitions
        # include single repeat stages that we'd prefer to keep that way as long as possible
        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
    else:
        # The default for EfficientNet truncates repeats to int via 'ceil'.
        # Any multiplier > 1.0 will result in an increased depth for every stage.
        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))

    # Proportionally distribute repeat count scaling to each block definition in the stage.
    # Allocation is done in reverse as it results in the first block being less likely to be scaled.
    # The first block makes less sense to repeat in most of the arch definitions.
    repeats_scaled = []
    for r in repeats[::-1]:
        rs = max(1, round((r / num_repeat * num_repeat_scaled)))
        repeats_scaled.append(rs)
        num_repeat -= r
        num_repeat_scaled -= rs
    repeats_scaled = repeats_scaled[::-1]

    # Apply the calculated scaling to each block arg in the stage
    sa_scaled = []
    for ba, rep in zip(stack_args, repeats_scaled):
        sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
    return sa_scaled

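# Illustrative example (added comment, not in the original file): with the tinynet-c
# depth multiplier of 0.85 and depth_trunc='round' (the setting used by
# _gen_efficientnet below), a stage defined as ['ir_r4_k5_s2_e6_c192_se0.25'] is
# scaled from 4 repeats to max(1, round(4 * 0.85)) = 3, while a 2-repeat stage
# stays at round(2 * 0.85) = 2.
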
def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
    """further decode the architecture definition into model-ready format"""
    arch_args = []
    for _, block_strings in enumerate(arch_def):
        assert isinstance(block_strings, list)
        stack_args = []
        repeats = []
        for block_str in block_strings:
            assert isinstance(block_str, str)
            ba, rep = _decode_block_str(block_str)
            stack_args.append(ba)
            repeats.append(rep)
        arch_args.append(_scale_stage_depth(
            stack_args, repeats, depth_multiplier, depth_trunc))
    return arch_args

class Swish(nn.Cell):
    """swish activation function"""

    def __init__(self):
        super(Swish, self).__init__()
        self.sigmoid = P.Sigmoid()

    def construct(self, x):
        return x * self.sigmoid(x)


@ms_function
def swish(x):
    # use the module-level P.Sigmoid primitive so no Cell is constructed inside the graph
    return x * sigmoid(x)

class BlockBuilder(nn.Cell):
    """build efficient-net convolution blocks"""

    def __init__(self, builder_in_channels, builder_block_args, channel_multiplier=1.0,
                 channel_divisor=8, channel_min=None, pad_type='', act_fn=None,
                 se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None,
                 drop_connect_rate=0., verbose=False):
        super(BlockBuilder, self).__init__()
        self.channel_multiplier = channel_multiplier
        self.channel_divisor = channel_divisor
        self.channel_min = channel_min
        self.pad_type = pad_type
        self.act_fn = Swish()
        self.se_gate_fn = se_gate_fn
        self.se_reduce_mid = se_reduce_mid
        self.bn_args = bn_args
        self.drop_connect_rate = drop_connect_rate
        self.verbose = verbose
        # updated during build
        self.in_chs = None
        self.block_idx = 0
        self.block_count = 0
        self.layer = self._make_layer(builder_in_channels, builder_block_args)

    def _round_channels(self, chs):
        return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)

    def _make_block(self, ba):
        """make the current block based on the block argument"""
        bt = ba.pop('block_type')
        ba['in_chs'] = self.in_chs
        ba['out_chs'] = self._round_channels(ba['out_chs'])
        if 'fake_in_chs' in ba and ba['fake_in_chs']:
            # this is a hack to work around mismatch in origin impl input filters
            ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs'])
        ba['bn_args'] = self.bn_args
        ba['pad_type'] = self.pad_type
        # block act fn overrides the model default
        ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
        assert ba['act_fn'] is not None
        if bt == 'ir':
            ba['drop_connect_rate'] = self.drop_connect_rate * \
                self.block_idx / self.block_count
            ba['se_gate_fn'] = self.se_gate_fn
            ba['se_reduce_mid'] = self.se_reduce_mid
            block = InvertedResidual(**ba)
        elif bt in ('ds', 'dsa'):
            ba['drop_connect_rate'] = self.drop_connect_rate * \
                self.block_idx / self.block_count
            block = DepthwiseSeparableConv(**ba)
        else:
            assert False, 'Unknown block type (%s) while building model.' % bt
        self.in_chs = ba['out_chs']
        return block

    def _make_stack(self, stack_args):
        """make a stack of blocks"""
        blocks = []
        # each stack (stage) contains a list of block arguments
        for i, ba in enumerate(stack_args):
            if i >= 1:
                # only the first block in any stack can have a stride > 1
                ba['stride'] = 1
            block = self._make_block(ba)
            blocks.append(block)
            self.block_idx += 1  # incr global idx (across all stacks)
        return nn.SequentialCell(blocks)

    def _make_layer(self, in_chs, block_args):
        """ Build the entire layer
        Args:
            in_chs: Number of input-channels passed to first block
            block_args: A list of lists, outer list defines stages, inner
                list contains strings defining block configuration(s)
        Return:
            List of block stacks (each stack wrapped in nn.SequentialCell)
        """
        self.in_chs = in_chs
        self.block_count = sum([len(x) for x in block_args])
        self.block_idx = 0
        blocks = []
        # outer list of block_args defines the stacks ('stages' by some conventions)
        for _, stack in enumerate(block_args):
            assert isinstance(stack, list)
            stack = self._make_stack(stack)
            blocks.append(stack)
        return nn.SequentialCell(blocks)

    def construct(self, x):
        return self.layer(x)

class DepthWiseConv(nn.Cell):
    """depth-wise convolution"""

    def __init__(self, in_planes, kernel_size, stride):
        super(DepthWiseConv, self).__init__()
        platform = context.get_context("device_target")
        weight_shape = [1, kernel_size, in_planes]
        weight_init = _initialize_weight_goog(shape=weight_shape)
        if platform == "GPU":
            self.depthwise_conv = P.Conv2D(out_channel=in_planes * 1,
                                           kernel_size=kernel_size,
                                           stride=stride,
                                           pad=int(kernel_size / 2),
                                           pad_mode="pad",
                                           group=in_planes)
            self.weight = Parameter(initializer(weight_init,
                                                [in_planes * 1, 1, kernel_size, kernel_size]))
        else:
            self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                          kernel_size=kernel_size,
                                                          stride=stride, pad_mode='pad',
                                                          pad=int(kernel_size / 2))
            self.weight = Parameter(initializer(weight_init,
                                                [1, in_planes, kernel_size, kernel_size]))

    def construct(self, x):
        x = self.depthwise_conv(x, self.weight)
        return x

class DropConnect(nn.Cell):
    """drop connect implementation"""

    def __init__(self, drop_connect_rate=0., seed0=0, seed1=0):
        super(DropConnect, self).__init__()
        self.shape = P.Shape()
        self.dtype = P.DType()
        self.keep_prob = 1 - drop_connect_rate
        self.dropout = P.Dropout(keep_prob=self.keep_prob)

    def construct(self, x):
        shape = self.shape(x)
        dtype = self.dtype(x)
        ones_tensor = P.Fill()(dtype, (shape[0], 1, 1, 1), 1)
        _, mask_ = self.dropout(ones_tensor)
        x = x * mask_
        return x


def drop_connect(inputs, training=False, drop_connect_rate=0.):
    if not training:
        return inputs
    return DropConnect(drop_connect_rate)(inputs)

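# Note (added comment, not in the original file): the dropout mask above has shape
# (batch, 1, 1, 1), so DropConnect randomly zeroes the residual branch for whole
# samples during training (stochastic-depth style regularization), while
# drop_connect() simply returns its input unchanged at inference time.
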
class SqueezeExcite(nn.Cell):
    """squeeze-excite implementation"""

    def __init__(self, in_chs, reduce_chs=None, act_fn=relu, gate_fn=sigmoid):
        super(SqueezeExcite, self).__init__()
        self.act_fn = Swish()
        self.gate_fn = gate_fn
        reduce_chs = reduce_chs or in_chs
        self.conv_reduce = nn.Conv2d(in_channels=in_chs, out_channels=reduce_chs,
                                     kernel_size=1, has_bias=True, pad_mode='pad')
        self.conv_expand = nn.Conv2d(in_channels=reduce_chs, out_channels=in_chs,
                                     kernel_size=1, has_bias=True, pad_mode='pad')
        self.avg_global_pool = P.ReduceMean(keep_dims=True)

    def construct(self, x):
        x_se = self.avg_global_pool(x, (2, 3))
        x_se = self.conv_reduce(x_se)
        x_se = self.act_fn(x_se)
        x_se = self.conv_expand(x_se)
        x_se = self.gate_fn(x_se)
        x = x * x_se
        return x

class DepthwiseSeparableConv(nn.Cell):
    """depth-wise convolution -> (squeeze-excite) -> point-wise convolution"""

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, pad_type='', act_fn=relu, noskip=False,
                 pw_kernel_size=1, pw_act=False, se_ratio=0., se_gate_fn=sigmoid,
                 bn_args=None, drop_connect_rate=0.):
        super(DepthwiseSeparableConv, self).__init__()
        assert stride in [1, 2], 'stride must be 1 or 2'
        self.has_se = se_ratio is not None and se_ratio > 0.
        self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
        self.has_pw_act = pw_act
        self.act_fn = Swish()
        self.drop_connect_rate = drop_connect_rate
        self.conv_dw = DepthWiseConv(in_chs, dw_kernel_size, stride)
        self.bn1 = _fused_bn(in_chs, **bn_args)
        if self.has_se:
            self.se = SqueezeExcite(in_chs, reduce_chs=max(1, int(in_chs * se_ratio)),
                                    act_fn=act_fn, gate_fn=se_gate_fn)
        self.conv_pw = _conv1x1(in_chs, out_chs)
        self.bn2 = _fused_bn(out_chs, **bn_args)

    def construct(self, x):
        """forward the depthwise separable conv"""
        identity = x
        x = self.conv_dw(x)
        x = self.bn1(x)
        x = self.act_fn(x)
        if self.has_se:
            x = self.se(x)
        x = self.conv_pw(x)
        x = self.bn2(x)
        if self.has_pw_act:
            x = self.act_fn(x)
        if self.has_residual:
            if self.drop_connect_rate > 0.:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x = x + identity
        return x

class InvertedResidual(nn.Cell):
    """inverted-residual block implementation"""

    def __init__(self, in_chs, out_chs, dw_kernel_size=3, stride=1,
                 pad_type='', act_fn=relu, pw_kernel_size=1,
                 noskip=False, exp_ratio=1., exp_kernel_size=1, se_ratio=0.,
                 se_reduce_mid=False, se_gate_fn=sigmoid, shuffle_type=None,
                 bn_args=None, drop_connect_rate=0.):
        super(InvertedResidual, self).__init__()
        mid_chs = int(in_chs * exp_ratio)
        self.has_se = se_ratio is not None and se_ratio > 0.
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.act_fn = Swish()
        self.drop_connect_rate = drop_connect_rate
        self.conv_pw = _conv(in_chs, mid_chs, exp_kernel_size)
        self.bn1 = _fused_bn(mid_chs, **bn_args)
        self.shuffle_type = shuffle_type
        if self.shuffle_type is not None and isinstance(exp_kernel_size, list):
            self.shuffle = None
        self.conv_dw = DepthWiseConv(mid_chs, dw_kernel_size, stride)
        self.bn2 = _fused_bn(mid_chs, **bn_args)
        if self.has_se:
            se_base_chs = mid_chs if se_reduce_mid else in_chs
            self.se = SqueezeExcite(
                mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)),
                act_fn=act_fn, gate_fn=se_gate_fn
            )
        self.conv_pwl = _conv(mid_chs, out_chs, pw_kernel_size)
        self.bn3 = _fused_bn(out_chs, **bn_args)

    def construct(self, x):
        """forward the inverted-residual block"""
        identity = x
        x = self.conv_pw(x)
        x = self.bn1(x)
        x = self.act_fn(x)
        x = self.conv_dw(x)
        x = self.bn2(x)
        x = self.act_fn(x)
        if self.has_se:
            x = self.se(x)
        x = self.conv_pwl(x)
        x = self.bn3(x)
        if self.has_residual:
            if self.drop_connect_rate > 0:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x = x + identity
        return x

class GenEfficientNet(nn.Cell):
    """Generate EfficientNet architecture"""

    def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
                 channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                 pad_type='', act_fn=relu, drop_rate=0., drop_connect_rate=0.,
                 se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=None,
                 global_pool='avg', head_conv='default', weight_init='goog'):
        super(GenEfficientNet, self).__init__()
        bn_args = _BN_ARGS_PT if bn_args is None else bn_args
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.num_features = num_features
        self.conv_stem = _conv(in_chans, stem_size, 3,
                               stride=2, padding=1, pad_mode='pad')
        self.bn1 = _fused_bn(stem_size, **bn_args)
        self.act_fn = Swish()
        in_chans = stem_size
        self.blocks = BlockBuilder(in_chans, block_args, channel_multiplier,
                                   channel_divisor, channel_min,
                                   pad_type, act_fn, se_gate_fn, se_reduce_mid,
                                   bn_args, drop_connect_rate, verbose=_DEBUG)
        in_chs = self.blocks.in_chs
        if not head_conv or head_conv == 'none':
            self.efficient_head = False
            self.conv_head = None
            assert in_chs == self.num_features
        else:
            self.efficient_head = head_conv == 'efficient'
            self.conv_head = _conv1x1(in_chs, self.num_features)
            self.bn2 = None if self.efficient_head else _fused_bn(
                self.num_features, **bn_args)
        self.global_pool = P.ReduceMean(keep_dims=True)
        self.classifier = _dense(self.num_features, self.num_classes)
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.drop_out = nn.Dropout(keep_prob=1 - self.drop_rate)

    def construct(self, x):
        """efficient net entry point"""
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act_fn(x)
        x = self.blocks(x)
        if self.efficient_head:
            x = self.global_pool(x, (2, 3))
            x = self.conv_head(x)
            x = self.act_fn(x)
            x = self.reshape(x, (self.shape(x)[0], -1))
        else:
            if self.conv_head is not None:
                x = self.conv_head(x)
                x = self.bn2(x)
            x = self.act_fn(x)
            x = self.global_pool(x, (2, 3))
            x = self.reshape(x, (self.shape(x)[0], -1))
        if self.training and self.drop_rate > 0.:
            x = self.drop_out(x)
        return self.classifier(x)

def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
    """Creates an EfficientNet model.

    Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
    Paper: https://arxiv.org/abs/1905.11946

    EfficientNet params
    name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
    'efficientnet-b0': (1.0, 1.0, 224, 0.2),
    'efficientnet-b1': (1.0, 1.1, 240, 0.2),
    'efficientnet-b2': (1.1, 1.2, 260, 0.3),
    'efficientnet-b3': (1.2, 1.4, 300, 0.3),
    'efficientnet-b4': (1.4, 1.8, 380, 0.4),
    'efficientnet-b5': (1.6, 2.2, 456, 0.4),
    'efficientnet-b6': (1.8, 2.6, 528, 0.5),
    'efficientnet-b7': (2.0, 3.1, 600, 0.5),

    Args:
        channel_multiplier (float): multiplier to number of channels per layer
        depth_multiplier (float): multiplier to number of repeats per stage
    """
    arch_def = [
        ['ds_r1_k3_s1_e1_c16_se0.25'],
        ['ir_r2_k3_s2_e6_c24_se0.25'],
        ['ir_r2_k5_s2_e6_c40_se0.25'],
        ['ir_r3_k3_s2_e6_c80_se0.25'],
        ['ir_r3_k5_s1_e6_c112_se0.25'],
        ['ir_r4_k5_s2_e6_c192_se0.25'],
        ['ir_r1_k3_s1_e6_c320_se0.25'],
    ]
    num_features = max(1280, _round_channels(
        1280, channel_multiplier, 8, None))
    model = GenEfficientNet(
        _decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
        num_classes=num_classes,
        stem_size=32,
        channel_multiplier=channel_multiplier,
        num_features=num_features,
        bn_args=_resolve_bn_args(kwargs),
        act_fn=Swish,
        **kwargs)
    return model

def tinynet(sub_model="c", num_classes=1000, in_chans=3, **kwargs):
    """ TinyNet Models """
    # choose a sub model
    r, w, d = TINYNET_CFG[sub_model]
    default_cfg = default_cfgs['efficientnet_b0']
    assert default_cfg['input_size'] == (3, 224, 224), (
        "All tinynet models are evolved from EfficientNet-B0, "
        "which has an input size of 3*224*224")
    channel, height, width = default_cfg['input_size']
    height = int(r * height)
    width = int(r * width)
    default_cfg['input_size'] = (channel, height, width)
    print("Data processing configuration for current model + dataset:")
    print("input_size:", default_cfg['input_size'])
    print("channel multiplier:%s, depth multiplier:%s, resolution multiplier:%s" % (w, d, r))
    model = _gen_efficientnet(
        channel_multiplier=w, depth_multiplier=d,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    return model

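# Minimal usage sketch (added for illustration; assumes MindSpore is installed and a
# device context has been configured, e.g. with mindspore.context.set_context):
#
#   import numpy as np
#   from mindspore import Tensor
#   import mindspore.common.dtype as mstype
#
#   net = tinynet(sub_model="c", num_classes=1000)            # builds TinyNet-C
#   dummy = Tensor(np.random.rand(1, 3, 184, 184), mstype.float32)
#   logits = net(dummy)                                       # shape: (1, 1000)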