
ssd.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """SSD net based MobilenetV2."""
  16. import mindspore.common.dtype as mstype
  17. import mindspore as ms
  18. import mindspore.nn as nn
  19. from mindspore import context, Tensor
  20. from mindspore.context import ParallelMode
  21. from mindspore.parallel._auto_parallel_context import auto_parallel_context
  22. from mindspore.communication.management import get_group_size
  23. from mindspore.ops import operations as P
  24. from mindspore.ops import functional as F
  25. from mindspore.ops import composite as C


def _make_divisible(v, divisor, min_value=None):
    """Ensures that all layers have a channel number that is divisible by 8."""
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
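
# Worked examples for _make_divisible (editor's sketch, not part of the original file):
#   _make_divisible(24.0, 8) -> 24   # already a multiple of 8, kept as-is
#   _make_divisible(10, 8)   -> 16   # rounds to 8 first, but 8 < 0.9 * 10, so bump by one divisor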


def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'):
    return nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
                     padding=0, pad_mode=pad_mod, has_bias=True)


def _bn(channel):
    return nn.BatchNorm2d(channel, eps=1e-3, momentum=0.97,
                          gamma_init=1, beta_init=0, moving_mean_init=0, moving_var_init=1)


def _last_conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same', pad=0):
    in_channels = in_channel
    out_channels = in_channel
    depthwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same',
                               padding=pad, group=in_channels)
    conv = _conv2d(in_channel, out_channel, kernel_size=1)
    return nn.SequentialCell([depthwise_conv, _bn(in_channel), nn.ReLU6(), conv])
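
# Editor's note: _last_conv2d builds a depthwise-separable prediction head:
# a depthwise conv (group == in_channels) followed by BN and ReLU6, then a
# 1x1 pointwise conv mapping to out_channel. This is what MultiBox below
# uses for both its localization and classification branches.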


class ConvBNReLU(nn.Cell):
    """
    Convolution/Depthwise fused with BatchNorm and ReLU block definition.

    Args:
        in_planes (int): Input channel.
        out_planes (int): Output channel.
        kernel_size (int): Input kernel size.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        groups (int): Channel group. 1 for convolution, input channel for depthwise. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
    """
    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        super(ConvBNReLU, self).__init__()
        padding = 0
        in_channels = in_planes
        out_channels = out_planes
        if groups == 1:
            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same', padding=padding)
        else:
            out_channels = in_planes
            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='same',
                             padding=padding, group=in_channels)
        layers = [conv, _bn(out_planes), nn.ReLU6()]
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output


class InvertedResidual(nn.Cell):
    """
    Inverted residual block definition.

    Args:
        inp (int): Input channel.
        oup (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.
        expand_ratio (int): Expand ratio of input channel.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> InvertedResidual(3, 256, 1, 1)
    """
    def __init__(self, inp, oup, stride, expand_ratio, last_relu=False):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]
        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = stride == 1 and inp == oup
        layers = []
        if expand_ratio != 1:
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False),
            _bn(oup),
        ])
        self.conv = nn.SequentialCell(layers)
        self.add = P.TensorAdd()
        self.cast = P.Cast()
        self.last_relu = last_relu
        self.relu = nn.ReLU6()

    def construct(self, x):
        identity = x
        x = self.conv(x)
        if self.use_res_connect:
            x = self.add(identity, x)
        if self.last_relu:
            x = self.relu(x)
        return x
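
# Worked example (editor's sketch): InvertedResidual(32, 32, stride=1, expand_ratio=6)
# expands 32 -> 192 channels with a 1x1 conv, applies a 3x3 depthwise conv on the
# 192 channels, projects back 192 -> 32 with a linear 1x1 conv, and, because
# stride == 1 and inp == oup, adds the input back as a residual connection.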


class FlattenConcat(nn.Cell):
    """
    Concatenate predictions into a single tensor.

    Args:
        config (dict): The default config of SSD.

    Returns:
        Tensor, flattened predictions.
    """
    def __init__(self, config):
        super(FlattenConcat, self).__init__()
        self.num_ssd_boxes = config.num_ssd_boxes
        self.concat = P.Concat(axis=1)
        self.transpose = P.Transpose()

    def construct(self, inputs):
        output = ()
        batch_size = F.shape(inputs[0])[0]
        for x in inputs:
            x = self.transpose(x, (0, 2, 3, 1))
            output += (F.reshape(x, (batch_size, -1)),)
        res = self.concat(output)
        return F.reshape(res, (batch_size, self.num_ssd_boxes, -1))
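
# Shape sketch (editor's comment, assuming the common 300x300 MobileNetV2-SSD
# config with num_ssd_boxes = 1917): each (B, C, H, W) feature map is transposed
# to (B, H, W, C), flattened to (B, H*W*C), concatenated along axis 1, and
# reshaped to (B, 1917, -1), where the last axis is 4 for loc or num_classes for cls.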


class MultiBox(nn.Cell):
    """
    Multibox conv layers. Each multibox layer contains class conf scores and localization predictions.

    Args:
        config (dict): The default config of SSD.

    Returns:
        Tensor, localization predictions.
        Tensor, class conf scores.
    """
    def __init__(self, config):
        super(MultiBox, self).__init__()
        num_classes = config.num_classes
        out_channels = config.extras_out_channels
        num_default = config.num_default
        loc_layers = []
        cls_layers = []
        for k, out_channel in enumerate(out_channels):
            loc_layers += [_last_conv2d(out_channel, 4 * num_default[k],
                                        kernel_size=3, stride=1, pad_mod='same', pad=0)]
            cls_layers += [_last_conv2d(out_channel, num_classes * num_default[k],
                                        kernel_size=3, stride=1, pad_mod='same', pad=0)]
        self.multi_loc_layers = nn.layer.CellList(loc_layers)
        self.multi_cls_layers = nn.layer.CellList(cls_layers)
        self.flatten_concat = FlattenConcat(config)

    def construct(self, inputs):
        loc_outputs = ()
        cls_outputs = ()
        for i in range(len(self.multi_loc_layers)):
            loc_outputs += (self.multi_loc_layers[i](inputs[i]),)
            cls_outputs += (self.multi_cls_layers[i](inputs[i]),)
        return self.flatten_concat(loc_outputs), self.flatten_concat(cls_outputs)
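
# Editor's note: with num_default[k] anchors per location on the k-th feature
# map, the loc head predicts 4 box offsets per anchor (hence 4 * num_default[k]
# output channels) and the cls head predicts one logit per class per anchor
# (num_classes * num_default[k] channels).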


class SSD300(nn.Cell):
    """
    SSD300 network. In this file the backbone is MobileNetV2 (see ssd_mobilenet_v2).

    Args:
        backbone (Cell): Backbone network.
        config (dict): The default config of SSD.

    Returns:
        Tensor, localization predictions.
        Tensor, class conf scores.

    Examples:
        >>> SSD300(backbone=ssd_mobilenet_v2(), config=config)
    """
    def __init__(self, backbone, config, is_training=True):
        super(SSD300, self).__init__()
        self.backbone = backbone
        in_channels = config.extras_in_channels
        out_channels = config.extras_out_channels
        ratios = config.extras_ratio
        strides = config.extras_srides  # spelling follows the config key
        residual_list = []
        for i in range(2, len(in_channels)):
            residual = InvertedResidual(in_channels[i], out_channels[i], stride=strides[i],
                                        expand_ratio=ratios[i], last_relu=True)
            residual_list.append(residual)
        self.multi_residual = nn.layer.CellList(residual_list)
        self.multi_box = MultiBox(config)
        self.is_training = is_training
        if not is_training:
            self.activation = P.Sigmoid()

    def construct(self, x):
        layer_out_13, output = self.backbone(x)
        multi_feature = (layer_out_13, output)
        feature = output
        for residual in self.multi_residual:
            feature = residual(feature)
            multi_feature += (feature,)
        pred_loc, pred_label = self.multi_box(multi_feature)
        if not self.is_training:
            pred_label = self.activation(pred_label)
        return pred_loc, pred_label


class SigmoidFocalClassificationLoss(nn.Cell):
    """
    Sigmoid focal loss for classification.

    Args:
        gamma (float): Hyper-parameter to balance the easy and hard examples. Default: 2.0.
        alpha (float): Hyper-parameter to balance the positive and negative examples. Default: 0.25.

    Returns:
        Tensor, the focal loss.
    """
    def __init__(self, gamma=2.0, alpha=0.25):
        super(SigmoidFocalClassificationLoss, self).__init__()
        self.sigmoid_cross_entropy = P.SigmoidCrossEntropyWithLogits()
        self.sigmoid = P.Sigmoid()
        self.pow = P.Pow()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.gamma = gamma
        self.alpha = alpha

    def construct(self, logits, label):
        label = self.onehot(label, F.shape(logits)[-1], self.on_value, self.off_value)
        sigmoid_cross_entropy = self.sigmoid_cross_entropy(logits, label)
        sigmoid = self.sigmoid(logits)
        label = F.cast(label, mstype.float32)
        p_t = label * sigmoid + (1 - label) * (1 - sigmoid)
        modulating_factor = self.pow(1 - p_t, self.gamma)
        alpha_weight_factor = label * self.alpha + (1 - label) * (1 - self.alpha)
        focal_loss = modulating_factor * alpha_weight_factor * sigmoid_cross_entropy
        return focal_loss
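
# Formula sketch (editor's comment): with p_t the probability of the true class,
#   FL = -alpha_t * (1 - p_t)^gamma * log(p_t)
# where alpha_t = alpha for positives and 1 - alpha for negatives. The
# (1 - p_t)^gamma modulating factor down-weights easy, well-classified examples;
# the loss is returned per element, unreduced.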


class SSDWithLossCell(nn.Cell):
    """
    Provide SSD training loss through network.

    Args:
        network (Cell): The training network.
        config (dict): SSD config.

    Returns:
        Tensor, the loss of the network.
    """
    def __init__(self, network, config):
        super(SSDWithLossCell, self).__init__()
        self.network = network
        self.less = P.Less()
        self.tile = P.Tile()
        self.reduce_sum = P.ReduceSum()
        self.reduce_mean = P.ReduceMean()
        self.expand_dims = P.ExpandDims()
        self.class_loss = SigmoidFocalClassificationLoss(config.gamma, config.alpha)
        self.loc_loss = nn.SmoothL1Loss()

    def construct(self, x, gt_loc, gt_label, num_matched_boxes):
        pred_loc, pred_label = self.network(x)
        mask = F.cast(self.less(0, gt_label), mstype.float32)
        num_matched_boxes = self.reduce_sum(F.cast(num_matched_boxes, mstype.float32))

        # Localization loss
        mask_loc = self.tile(self.expand_dims(mask, -1), (1, 1, 4))
        smooth_l1 = self.loc_loss(pred_loc, gt_loc) * mask_loc
        loss_loc = self.reduce_sum(self.reduce_mean(smooth_l1, -1), -1)

        # Classification loss
        loss_cls = self.class_loss(pred_label, gt_label)
        loss_cls = self.reduce_sum(loss_cls, (1, 2))

        return self.reduce_sum((loss_cls + loss_loc) / num_matched_boxes)
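
# Editor's note: gt_label > 0 marks anchors matched to a ground-truth box. The
# localization loss is masked to those anchors only, while the focal
# classification loss runs over all anchors (the modulating factor suppresses
# the many easy negatives). Both terms are normalized by the total number of
# matched boxes across the batch.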


class TrainingWrapper(nn.Cell):
    """
    Encapsulation class of SSD network training.

    Appends an optimizer to the training network; after that, the construct
    function can be called to create the backward graph.

    Args:
        network (Cell): The training network. Note that the loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        sens (Number): The adjust parameter. Default: 1.0.
    """
    def __init__(self, network, optimizer, sens=1.0):
        super(TrainingWrapper, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = ms.ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("gradients_mean")
            if auto_parallel_context().get_device_num_is_set():
                degree = context.get_auto_parallel_context("device_num")
            else:
                degree = get_group_size()
            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, *args):
        weights = self.weights
        loss = self.network(*args)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*args, sens)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
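
# Editor's note: `sens` is the initial gradient fed into the backward graph (the
# sensitivity of the loss with respect to itself), so sens=1.0 yields plain
# dL/dw; larger values scale all gradients uniformly, a common loss-scaling trick.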


class SSDWithMobileNetV2(nn.Cell):
    """
    MobileNetV2 architecture for the SSD backbone.

    Args:
        width_mult (float): Channel width multiplier, rounded to a multiple of round_nearest. Default: 1.0.
        inverted_residual_setting (list): Inverted residual settings. Default: None.
        round_nearest (int): Round channel counts to a multiple of this value. Default: 8.

    Returns:
        Tensor, the 13th feature after ConvBNReLU in MobileNetV2.
        Tensor, the last feature in MobileNetV2.

    Examples:
        >>> SSDWithMobileNetV2()
    """
    def __init__(self, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
        super(SSDWithMobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]
        if len(inverted_residual_setting[0]) != 4:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "and each element should be a 4-element list, got {}".format(inverted_residual_setting))

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2)]

        # building inverted residual blocks
        layer_index = 0
        for t, c, n, s in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                if layer_index == 13:
                    hidden_dim = int(round(input_channel * t))
                    self.expand_layer_conv_13 = ConvBNReLU(input_channel, hidden_dim, kernel_size=1)
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
                layer_index += 1

        # building last several layers
        features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
        self.features_1 = nn.SequentialCell(features[:14])
        self.features_2 = nn.SequentialCell(features[14:])

    def construct(self, x):
        out = self.features_1(x)
        expand_layer_conv_13 = self.expand_layer_conv_13(out)
        out = self.features_2(out)
        return expand_layer_conv_13, out

    def get_out_channels(self):
        return self.last_channel


def ssd_mobilenet_v2(**kwargs):
    return SSDWithMobileNetV2(**kwargs)
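
# Minimal training wiring sketch (editor's addition; `config` and `dataset` are
# placeholders that this file does not define, and the optimizer settings are
# illustrative only):
#
#   backbone = ssd_mobilenet_v2()
#   net = SSDWithLossCell(SSD300(backbone, config, is_training=True), config)
#   opt = nn.Momentum(net.trainable_params(), learning_rate=0.05, momentum=0.9)
#   train_net = TrainingWrapper(net, opt)
#   for x, gt_loc, gt_label, num_matched in dataset:
#       loss = train_net(x, gt_loc, gt_label, num_matched)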