
ssd.py 18 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""SSD net based MobilenetV2."""

import mindspore.common.dtype as mstype
import mindspore as ms
import mindspore.nn as nn
from mindspore import Parameter, context, Tensor
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from mindspore.communication.management import get_group_size
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.initializer import initializer


def _make_divisible(v, divisor, min_value=None):
    """Ensures that all layers have a channel number that is divisible by 8."""
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
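
# Illustrative values for the rounding helper above (not part of the original file):
#   _make_divisible(32 * 1.0, 8)  -> 32
#   _make_divisible(8.4, 8)       -> 8    (clamped to the divisor, i.e. the minimum width)
#   _make_divisible(11, 8)        -> 16   (plain rounding to 8 would drop more than 10%, so it bumps up)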


def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'):
    return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
                     padding=0, pad_mode=pad_mod, has_bias=True)


def _bn(channel):
    return nn.BatchNorm2d(channel, eps=1e-3, momentum=0.97,
                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _last_conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same', pad=0):
    depthwise_conv = DepthwiseConv(in_channel, kernel_size, stride, pad_mode='same', pad=pad)
    conv = _conv2d(in_channel, out_channel, kernel_size=1)
    return nn.SequentialCell([depthwise_conv, _bn(in_channel), nn.ReLU6(), conv])
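
# Sketch of what the helpers above build (channel values are assumptions for illustration):
#   _conv2d(576, 24, kernel_size=1)        # plain 1x1 convolution with bias
#   _bn(576)                               # BatchNorm2d with eps=1e-3, momentum=0.97
#   _last_conv2d(576, 24, kernel_size=3)   # depthwise 3x3 -> BN -> ReLU6 -> pointwise 1x1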


class ConvBNReLU(nn.Cell):
    """
    Convolution/Depthwise fused with Batchnorm and ReLU block definition.

    Args:
        in_planes (int): Input channel.
        out_planes (int): Output channel.
        kernel_size (int): Input kernel size.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        groups (int): Channel group. 1 for standard convolution, in_planes for depthwise convolution. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        super(ConvBNReLU, self).__init__()
        padding = 0
        if groups == 1:
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='same',
                             padding=padding)
        else:
            conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='same', pad=padding)
        layers = [conv, _bn(out_planes), nn.ReLU6()]
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output
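
# Illustrative instances of the block above (channel values are assumptions):
#   ConvBNReLU(3, 32, kernel_size=3, stride=2)     # standard convolution
#   ConvBNReLU(32, 32, kernel_size=3, groups=32)   # depthwise convolution (groups == in_planes)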


class DepthwiseConv(nn.Cell):
    """
    Depthwise Convolution wrapper definition.

    Args:
        in_planes (int): Input channel.
        kernel_size (int): Input kernel size.
        stride (int): Stride size.
        pad_mode (str): Pad mode, one of ('pad', 'same', 'valid').
        pad (int): Padding size when pad_mode is 'pad'.
        channel_multiplier (int): Output channel multiplier.
        has_bias (bool): Has bias or not.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1)
    """

    def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False):
        super(DepthwiseConv, self).__init__()
        self.has_bias = has_bias
        self.in_channels = in_planes
        self.channel_multiplier = channel_multiplier
        self.out_channels = in_planes * channel_multiplier
        self.kernel_size = (kernel_size, kernel_size)
        self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier,
                                                      kernel_size=self.kernel_size,
                                                      stride=stride, pad_mode=pad_mode, pad=pad)
        self.bias_add = P.BiasAdd()
        weight_shape = [channel_multiplier, in_planes, *self.kernel_size]
        self.weight = Parameter(initializer('ones', weight_shape), name='weight')
        if has_bias:
            bias_shape = [channel_multiplier * in_planes]
            self.bias = Parameter(initializer('zeros', bias_shape), name='bias')
        else:
            self.bias = None

    def construct(self, x):
        output = self.depthwise_conv(x, self.weight)
        if self.has_bias:
            output = self.bias_add(output, self.bias)
        return output
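
# Note on the wrapper above: the kernel is held as a Parameter of shape
# [channel_multiplier, in_planes, k, k] and applied with P.DepthwiseConv2dNative, so every
# input channel gets its own channel_multiplier filters and
# out_channels == in_planes * channel_multiplier. Illustrative sketch (values assumed):
#   dw = DepthwiseConv(32, kernel_size=3, stride=1, pad_mode='same', pad=0)
#   # input (N, 32, H, W) -> output (N, 32, H, W)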


class InvertedResidual(nn.Cell):
    """
    Inverted residual block definition.

    Args:
        inp (int): Input channel.
        oup (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        expand_ratio (int): Expand ratio of the input channel.
        last_relu (bool): Whether to apply ReLU6 to the block output. Default: False.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> InvertedResidual(3, 256, 1, 1)
    """

    def __init__(self, inp, oup, stride, expand_ratio, last_relu=False):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]
        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = stride == 1 and inp == oup
        layers = []
        if expand_ratio != 1:
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False),
            _bn(oup),
        ])
        self.conv = nn.SequentialCell(layers)
        self.add = P.TensorAdd()
        self.cast = P.Cast()
        self.last_relu = last_relu
        self.relu = nn.ReLU6()

    def construct(self, x):
        identity = x
        x = self.conv(x)
        if self.use_res_connect:
            x = self.add(identity, x)
        if self.last_relu:
            x = self.relu(x)
        return x
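
# The block above follows the MobileNetV2 inverted-residual pattern: optional 1x1 expansion,
# depthwise 3x3 convolution, then a linear 1x1 projection. Illustrative instances (values assumed):
#   InvertedResidual(32, 32, stride=1, expand_ratio=6)   # identity shortcut (stride 1, inp == oup)
#   InvertedResidual(32, 64, stride=2, expand_ratio=6)   # no shortcut, spatial downsampling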


class FlattenConcat(nn.Cell):
    """
    Concatenate predictions into a single tensor.

    Args:
        config (dict): The default config of SSD.

    Returns:
        Tensor, flattened predictions.
    """

    def __init__(self, config):
        super(FlattenConcat, self).__init__()
        self.num_ssd_boxes = config.num_ssd_boxes
        self.concat = P.Concat(axis=1)
        self.transpose = P.Transpose()

    def construct(self, inputs):
        output = ()
        batch_size = F.shape(inputs[0])[0]
        for x in inputs:
            x = self.transpose(x, (0, 2, 3, 1))
            output += (F.reshape(x, (batch_size, -1)),)
        res = self.concat(output)
        return F.reshape(res, (batch_size, self.num_ssd_boxes, -1))
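
# Shape sketch for the cell above: each feature level (N, C, H, W) is transposed to NHWC and
# flattened to (N, H*W*C); after concatenation across levels, the localization tensor is
# reshaped to (N, num_ssd_boxes, 4) and the class tensor to (N, num_ssd_boxes, num_classes).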


class MultiBox(nn.Cell):
    """
    Multibox conv layers. Each multibox layer contains class conf scores and localization predictions.

    Args:
        config (dict): The default config of SSD.

    Returns:
        Tensor, localization predictions.
        Tensor, class conf scores.
    """

    def __init__(self, config):
        super(MultiBox, self).__init__()
        num_classes = config.num_classes
        out_channels = config.extras_out_channels
        num_default = config.num_default
        loc_layers = []
        cls_layers = []
        for k, out_channel in enumerate(out_channels):
            loc_layers += [_last_conv2d(out_channel, 4 * num_default[k],
                                        kernel_size=3, stride=1, pad_mod='same', pad=0)]
            cls_layers += [_last_conv2d(out_channel, num_classes * num_default[k],
                                        kernel_size=3, stride=1, pad_mod='same', pad=0)]
        self.multi_loc_layers = nn.layer.CellList(loc_layers)
        self.multi_cls_layers = nn.layer.CellList(cls_layers)
        self.flatten_concat = FlattenConcat(config)

    def construct(self, inputs):
        loc_outputs = ()
        cls_outputs = ()
        for i in range(len(self.multi_loc_layers)):
            loc_outputs += (self.multi_loc_layers[i](inputs[i]),)
            cls_outputs += (self.multi_cls_layers[i](inputs[i]),)
        return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs)
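
# Head layout sketch (grounded in the constructor above): for feature level k with
# num_default[k] anchors per cell, the loc head emits 4 * num_default[k] channels and the
# cls head emits num_classes * num_default[k] channels; FlattenConcat then merges all levels
# into (batch, num_ssd_boxes, 4) and (batch, num_ssd_boxes, num_classes).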


class SSD300(nn.Cell):
    """
    SSD300 Network. Default backbone is MobileNetV2.

    Args:
        backbone (Cell): Backbone Network.
        config (dict): The default config of SSD.
        is_training (bool): Whether the network is used for training. Default: True.

    Returns:
        Tensor, localization predictions.
        Tensor, class conf scores.

    Examples:
        >>> SSD300(backbone=ssd_mobilenet_v2(), config=config)
    """

    def __init__(self, backbone, config, is_training=True):
        super(SSD300, self).__init__()
        self.backbone = backbone
        in_channels = config.extras_in_channels
        out_channels = config.extras_out_channels
        ratios = config.extras_ratio
        strides = config.extras_srides
        residual_list = []
        for i in range(2, len(in_channels)):
            residual = InvertedResidual(in_channels[i], out_channels[i], stride=strides[i],
                                        expand_ratio=ratios[i], last_relu=True)
            residual_list.append(residual)
        self.multi_residual = nn.layer.CellList(residual_list)
        self.multi_box = MultiBox(config)
        self.is_training = is_training
        if not is_training:
            self.activation = P.Sigmoid()

    def construct(self, x):
        layer_out_13, output = self.backbone(x)
        multi_feature = (layer_out_13, output)
        feature = output
        for residual in self.multi_residual:
            feature = residual(feature)
            multi_feature += (feature,)
        pred_loc, pred_label = self.multi_box(multi_feature)
        if not self.is_training:
            pred_label = self.activation(pred_label)
        return pred_loc, pred_label
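
# Illustrative usage of SSD300 (the `config` object is assumed to come from the accompanying
# config module of this repository):
#   net = SSD300(backbone=ssd_mobilenet_v2(), config=config, is_training=False)
#   pred_loc, pred_label = net(img)   # in eval mode pred_label is passed through Sigmoid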


class SigmoidFocalClassificationLoss(nn.Cell):
    """
    Sigmoid focal loss for classification.

    Args:
        gamma (float): Hyper-parameter to balance the easy and hard examples. Default: 2.0.
        alpha (float): Hyper-parameter to balance the positive and negative examples. Default: 0.25.

    Returns:
        Tensor, the focal loss.
    """

    def __init__(self, gamma=2.0, alpha=0.25):
        super(SigmoidFocalClassificationLoss, self).__init__()
        self.sigmoid_cross_entropy = P.SigmoidCrossEntropyWithLogits()
        self.sigmoid = P.Sigmoid()
        self.pow = P.Pow()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.gamma = gamma
        self.alpha = alpha

    def construct(self, logits, label):
        label = self.onehot(label, F.shape(logits)[-1], self.on_value, self.off_value)
        sigmoid_cross_entropy = self.sigmoid_cross_entropy(logits, label)
        sigmoid = self.sigmoid(logits)
        label = F.cast(label, mstype.float32)
        p_t = label * sigmoid + (1 - label) * (1 - sigmoid)
        modulating_factor = self.pow(1 - p_t, self.gamma)
        alpha_weight_factor = label * self.alpha + (1 - label) * (1 - self.alpha)
        focal_loss = modulating_factor * alpha_weight_factor * sigmoid_cross_entropy
        return focal_loss
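
# The cell above is the alpha-balanced focal loss of Lin et al. (2017):
#   FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)
# construct() realizes this as modulating_factor * alpha_weight_factor * sigmoid cross-entropy,
# where p_t is the sigmoid probability assigned to the true class.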


class SSDWithLossCell(nn.Cell):
    """
    Provide SSD training loss through network.

    Args:
        network (Cell): The training network.
        config (dict): SSD config.

    Returns:
        Tensor, the loss of the network.
    """

    def __init__(self, network, config):
        super(SSDWithLossCell, self).__init__()
        self.network = network
        self.less = P.Less()
        self.tile = P.Tile()
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.expand_dims = P.ExpandDims()
        self.class_loss = SigmoidFocalClassificationLoss(config.gamma, config.alpha)
        self.loc_loss = nn.SmoothL1Loss()

    def construct(self, x, gt_loc, gt_label, num_matched_boxes):
        pred_loc, pred_label = self.network(x)
        mask = F.cast(self.less(0, gt_label), mstype.float32)
        num_matched_boxes = self.reduce_sum(F.cast(num_matched_boxes, mstype.float32))

        # Localization Loss
        mask_loc = self.tile(self.expand_dims(mask, -1), (1, 1, 4))
        smooth_l1 = self.loc_loss(pred_loc, gt_loc) * mask_loc
        loss_loc = self.reduce_sum(self.reduce_mean(smooth_l1, -1), -1)

        # Classification Loss
        loss_cls = self.class_loss(pred_label, gt_label)
        loss_cls = self.reduce_sum(loss_cls, (1, 2))

        return self.reduce_sum((loss_cls + loss_loc) / num_matched_boxes)
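
# Illustrative usage (input shapes follow the SSD data pipeline and are assumptions here):
#   net = SSD300(ssd_mobilenet_v2(), config)
#   net_with_loss = SSDWithLossCell(net, config)
#   loss = net_with_loss(img, gt_loc, gt_label, num_matched_boxes)
# The returned scalar is the focal classification loss plus the masked SmoothL1 localization
# loss, normalized by the number of matched default boxes.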


class TrainingWrapper(nn.Cell):
    """
    Encapsulation class of SSD network training.

    Append an optimizer to the training network; after that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        sens (Number): The adjust parameter. Default: 1.0.
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ms.ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("mirror_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *args):
        weights = self.weights
        loss = self.network(*args)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*args, sens)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
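
# Illustrative one-step training sketch (optimizer choice and hyper-parameters are assumptions):
#   opt = nn.Momentum(net_with_loss.trainable_params(), learning_rate=0.05, momentum=0.9)
#   train_net = TrainingWrapper(net_with_loss, opt)
#   loss = train_net(img, gt_loc, gt_label, num_matched_boxes)   # forward, backward and update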


class SSDWithMobileNetV2(nn.Cell):
    """
    MobileNetV2 architecture for SSD backbone.

    Args:
        width_mult (float): Channel width multiplier; channel numbers are rounded to multiples of
            round_nearest. Default: 1.0.
        inverted_residual_setting (list): Inverted residual settings. Default: None.
        round_nearest (int): Round channel numbers to the nearest multiple of this value. Default: 8.

    Returns:
        Tensor, the 13th feature after ConvBNReLU in MobileNetV2.
        Tensor, the last feature in MobileNetV2.

    Examples:
        >>> SSDWithMobileNetV2()
    """

    def __init__(self, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
        super(SSDWithMobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]
        if len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2)]

        # building inverted residual blocks
        layer_index = 0
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                if layer_index == 13:
                    hidden_dim = int(round(input_channel * t))
                    self.expand_layer_conv_13 = ConvBNReLU(input_channel, hidden_dim, kernel_size=1)
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
                layer_index += 1

        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
        self.features_1 = nn.SequentialCell(features[:14])
        self.features_2 = nn.SequentialCell(features[14:])

    def construct(self, x):
        out = self.features_1(x)
        expand_layer_conv_13 = self.expand_layer_conv_13(out)
        out = self.features_2(out)
        return expand_layer_conv_13, out

    def get_out_channels(self):
        return self.last_channel


def ssd_mobilenet_v2(**kwargs):
    return SSDWithMobileNetV2(**kwargs)
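
# Putting it together (the config import path is an assumption based on the usual model zoo
# layout; adjust it to the actual repository structure):
#   from src.config import config
#   backbone = ssd_mobilenet_v2()
#   ssd = SSD300(backbone, config)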