You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

yolo.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """YOLOv3 based on DarkNet."""
  16. import mindspore as ms
  17. import mindspore.nn as nn
  18. from mindspore.common.tensor import Tensor
  19. from mindspore import context
  20. from mindspore.parallel._auto_parallel_context import auto_parallel_context
  21. from mindspore.communication.management import get_group_size
  22. from mindspore.ops import operations as P
  23. from mindspore.ops import functional as F
  24. from mindspore.ops import composite as C
  25. from src.darknet import DarkNet, ResidualBlock
  26. from src.config import ConfigYOLOV3DarkNet53
  27. from src.loss import XYLoss, WHLoss, ConfidenceLoss, ClassLoss
  28. def _conv_bn_relu(in_channel,
  29. out_channel,
  30. ksize,
  31. stride=1,
  32. padding=0,
  33. dilation=1,
  34. alpha=0.1,
  35. momentum=0.9,
  36. eps=1e-5,
  37. pad_mode="same"):
  38. """Get a conv2d batchnorm and relu layer"""
  39. return nn.SequentialCell(
  40. [nn.Conv2d(in_channel,
  41. out_channel,
  42. kernel_size=ksize,
  43. stride=stride,
  44. padding=padding,
  45. dilation=dilation,
  46. pad_mode=pad_mode),
  47. nn.BatchNorm2d(out_channel, momentum=momentum, eps=eps),
  48. nn.LeakyReLU(alpha)]
  49. )
  50. class YoloBlock(nn.Cell):
  51. """
  52. YoloBlock for YOLOv3.
  53. Args:
  54. in_channels: Integer. Input channel.
  55. out_chls: Interger. Middle channel.
  56. out_channels: Integer. Output channel.
  57. Returns:
  58. Tuple, tuple of output tensor,(f1,f2,f3).
  59. Examples:
  60. YoloBlock(1024, 512, 255)
  61. """
  62. def __init__(self, in_channels, out_chls, out_channels):
  63. super(YoloBlock, self).__init__()
  64. out_chls_2 = out_chls*2
  65. self.conv0 = _conv_bn_relu(in_channels, out_chls, ksize=1)
  66. self.conv1 = _conv_bn_relu(out_chls, out_chls_2, ksize=3)
  67. self.conv2 = _conv_bn_relu(out_chls_2, out_chls, ksize=1)
  68. self.conv3 = _conv_bn_relu(out_chls, out_chls_2, ksize=3)
  69. self.conv4 = _conv_bn_relu(out_chls_2, out_chls, ksize=1)
  70. self.conv5 = _conv_bn_relu(out_chls, out_chls_2, ksize=3)
  71. self.conv6 = nn.Conv2d(out_chls_2, out_channels, kernel_size=1, stride=1, has_bias=True)
  72. def construct(self, x):
  73. c1 = self.conv0(x)
  74. c2 = self.conv1(c1)
  75. c3 = self.conv2(c2)
  76. c4 = self.conv3(c3)
  77. c5 = self.conv4(c4)
  78. c6 = self.conv5(c5)
  79. out = self.conv6(c6)
  80. return c5, out
  81. class YOLOv3(nn.Cell):
  82. """
  83. YOLOv3 Network.
  84. Note:
  85. backbone = darknet53
  86. Args:
  87. backbone_shape: List. Darknet output channels shape.
  88. backbone: Cell. Backbone Network.
  89. out_channel: Interger. Output channel.
  90. Returns:
  91. Tensor, output tensor.
  92. Examples:
  93. YOLOv3(backbone_shape=[64, 128, 256, 512, 1024]
  94. backbone=darknet53(),
  95. out_channel=255)
  96. """
  97. def __init__(self, backbone_shape, backbone, out_channel):
  98. super(YOLOv3, self).__init__()
  99. self.out_channel = out_channel
  100. self.backbone = backbone
  101. self.backblock0 = YoloBlock(backbone_shape[-1], out_chls=backbone_shape[-2], out_channels=out_channel)
  102. self.conv1 = _conv_bn_relu(in_channel=backbone_shape[-2], out_channel=backbone_shape[-2]//2, ksize=1)
  103. self.backblock1 = YoloBlock(in_channels=backbone_shape[-2]+backbone_shape[-3],
  104. out_chls=backbone_shape[-3],
  105. out_channels=out_channel)
  106. self.conv2 = _conv_bn_relu(in_channel=backbone_shape[-3], out_channel=backbone_shape[-3]//2, ksize=1)
  107. self.backblock2 = YoloBlock(in_channels=backbone_shape[-3]+backbone_shape[-4],
  108. out_chls=backbone_shape[-4],
  109. out_channels=out_channel)
  110. self.concat = P.Concat(axis=1)
  111. def construct(self, x):
  112. # input_shape of x is (batch_size, 3, h, w)
  113. # feature_map1 is (batch_size, backbone_shape[2], h/8, w/8)
  114. # feature_map2 is (batch_size, backbone_shape[3], h/16, w/16)
  115. # feature_map3 is (batch_size, backbone_shape[4], h/32, w/32)
  116. img_hight = P.Shape()(x)[2]
  117. img_width = P.Shape()(x)[3]
  118. feature_map1, feature_map2, feature_map3 = self.backbone(x)
  119. con1, big_object_output = self.backblock0(feature_map3)
  120. con1 = self.conv1(con1)
  121. ups1 = P.ResizeNearestNeighbor((img_hight / 16, img_width / 16))(con1)
  122. con1 = self.concat((ups1, feature_map2))
  123. con2, medium_object_output = self.backblock1(con1)
  124. con2 = self.conv2(con2)
  125. ups2 = P.ResizeNearestNeighbor((img_hight / 8, img_width / 8))(con2)
  126. con3 = self.concat((ups2, feature_map1))
  127. _, small_object_output = self.backblock2(con3)
  128. return big_object_output, medium_object_output, small_object_output
  129. class DetectionBlock(nn.Cell):
  130. """
  131. YOLOv3 detection Network. It will finally output the detection result.
  132. Args:
  133. scale: Character.
  134. config: ConfigYOLOV3DarkNet53, Configuration instance.
  135. is_training: Bool, Whether train or not, default True.
  136. Returns:
  137. Tuple, tuple of output tensor,(f1,f2,f3).
  138. Examples:
  139. DetectionBlock(scale='l',stride=32)
  140. """
  141. def __init__(self, scale, config=ConfigYOLOV3DarkNet53(), is_training=True):
  142. super(DetectionBlock, self).__init__()
  143. self.config = config
  144. if scale == 's':
  145. idx = (0, 1, 2)
  146. elif scale == 'm':
  147. idx = (3, 4, 5)
  148. elif scale == 'l':
  149. idx = (6, 7, 8)
  150. else:
  151. raise KeyError("Invalid scale value for DetectionBlock")
  152. self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
  153. self.num_anchors_per_scale = 3
  154. self.num_attrib = 4+1+self.config.num_classes
  155. self.lambda_coord = 1
  156. self.sigmoid = nn.Sigmoid()
  157. self.reshape = P.Reshape()
  158. self.tile = P.Tile()
  159. self.concat = P.Concat(axis=-1)
  160. self.conf_training = is_training
  161. def construct(self, x, input_shape):
  162. num_batch = P.Shape()(x)[0]
  163. grid_size = P.Shape()(x)[2:4]
  164. # Reshape and transpose the feature to [n, grid_size[0], grid_size[1], 3, num_attrib]
  165. prediction = P.Reshape()(x, (num_batch,
  166. self.num_anchors_per_scale,
  167. self.num_attrib,
  168. grid_size[0],
  169. grid_size[1]))
  170. prediction = P.Transpose()(prediction, (0, 3, 4, 1, 2))
  171. range_x = range(grid_size[1])
  172. range_y = range(grid_size[0])
  173. grid_x = P.Cast()(F.tuple_to_array(range_x), ms.float32)
  174. grid_y = P.Cast()(F.tuple_to_array(range_y), ms.float32)
  175. # Tensor of shape [grid_size[0], grid_size[1], 1, 1] representing the coordinate of x/y axis for each grid
  176. # [batch, gridx, gridy, 1, 1]
  177. grid_x = self.tile(self.reshape(grid_x, (1, 1, -1, 1, 1)), (1, grid_size[0], 1, 1, 1))
  178. grid_y = self.tile(self.reshape(grid_y, (1, -1, 1, 1, 1)), (1, 1, grid_size[1], 1, 1))
  179. # Shape is [grid_size[0], grid_size[1], 1, 2]
  180. grid = self.concat((grid_x, grid_y))
  181. box_xy = prediction[:, :, :, :, :2]
  182. box_wh = prediction[:, :, :, :, 2:4]
  183. box_confidence = prediction[:, :, :, :, 4:5]
  184. box_probs = prediction[:, :, :, :, 5:]
  185. # gridsize1 is x
  186. # gridsize0 is y
  187. box_xy = (self.sigmoid(box_xy) + grid) / P.Cast()(F.tuple_to_array((grid_size[1], grid_size[0])), ms.float32)
  188. # box_wh is w->h
  189. box_wh = P.Exp()(box_wh) * self.anchors / input_shape
  190. box_confidence = self.sigmoid(box_confidence)
  191. box_probs = self.sigmoid(box_probs)
  192. if self.conf_training:
  193. return grid, prediction, box_xy, box_wh
  194. return self.concat((box_xy, box_wh, box_confidence, box_probs))
  195. class Iou(nn.Cell):
  196. """Calculate the iou of boxes"""
  197. def __init__(self):
  198. super(Iou, self).__init__()
  199. self.min = P.Minimum()
  200. self.max = P.Maximum()
  201. def construct(self, box1, box2):
  202. # box1: pred_box [batch, gx, gy, anchors, 1, 4] ->4: [x_center, y_center, w, h]
  203. # box2: gt_box [batch, 1, 1, 1, maxbox, 4]
  204. # convert to topLeft and rightDown
  205. box1_xy = box1[:, :, :, :, :, :2]
  206. box1_wh = box1[:, :, :, :, :, 2:4]
  207. box1_mins = box1_xy - box1_wh / F.scalar_to_array(2.0) # topLeft
  208. box1_maxs = box1_xy + box1_wh / F.scalar_to_array(2.0) # rightDown
  209. box2_xy = box2[:, :, :, :, :, :2]
  210. box2_wh = box2[:, :, :, :, :, 2:4]
  211. box2_mins = box2_xy - box2_wh / F.scalar_to_array(2.0)
  212. box2_maxs = box2_xy + box2_wh / F.scalar_to_array(2.0)
  213. intersect_mins = self.max(box1_mins, box2_mins)
  214. intersect_maxs = self.min(box1_maxs, box2_maxs)
  215. intersect_wh = self.max(intersect_maxs - intersect_mins, F.scalar_to_array(0.0))
  216. # P.squeeze: for effiecient slice
  217. intersect_area = P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 0:1]) * \
  218. P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 1:2])
  219. box1_area = P.Squeeze(-1)(box1_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box1_wh[:, :, :, :, :, 1:2])
  220. box2_area = P.Squeeze(-1)(box2_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box2_wh[:, :, :, :, :, 1:2])
  221. iou = intersect_area / (box1_area + box2_area - intersect_area)
  222. # iou : [batch, gx, gy, anchors, maxboxes]
  223. return iou
  224. class YoloLossBlock(nn.Cell):
  225. """
  226. Loss block cell of YOLOV3 network.
  227. """
  228. def __init__(self, scale, config=ConfigYOLOV3DarkNet53()):
  229. super(YoloLossBlock, self).__init__()
  230. self.config = config
  231. if scale == 's':
  232. # anchor mask
  233. idx = (0, 1, 2)
  234. elif scale == 'm':
  235. idx = (3, 4, 5)
  236. elif scale == 'l':
  237. idx = (6, 7, 8)
  238. else:
  239. raise KeyError("Invalid scale value for DetectionBlock")
  240. self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
  241. self.ignore_threshold = Tensor(self.config.ignore_threshold, ms.float32)
  242. self.concat = P.Concat(axis=-1)
  243. self.iou = Iou()
  244. self.reduce_max = P.ReduceMax(keep_dims=False)
  245. self.xy_loss = XYLoss()
  246. self.wh_loss = WHLoss()
  247. self.confidenceLoss = ConfidenceLoss()
  248. self.classLoss = ClassLoss()
  249. def construct(self, grid, prediction, pred_xy, pred_wh, y_true, gt_box, input_shape):
  250. # prediction : origin output from yolo
  251. # pred_xy: (sigmoid(xy)+grid)/grid_size
  252. # pred_wh: (exp(wh)*anchors)/input_shape
  253. # y_true : after normalize
  254. # gt_box: [batch, maxboxes, xyhw] after normalize
  255. object_mask = y_true[:, :, :, :, 4:5]
  256. class_probs = y_true[:, :, :, :, 5:]
  257. grid_shape = P.Shape()(prediction)[1:3]
  258. grid_shape = P.Cast()(F.tuple_to_array(grid_shape[::-1]), ms.float32)
  259. pred_boxes = self.concat((pred_xy, pred_wh))
  260. true_xy = y_true[:, :, :, :, :2] * grid_shape - grid
  261. true_wh = y_true[:, :, :, :, 2:4]
  262. true_wh = P.Select()(P.Equal()(true_wh, 0.0),
  263. P.Fill()(P.DType()(true_wh),
  264. P.Shape()(true_wh), 1.0),
  265. true_wh)
  266. true_wh = P.Log()(true_wh / self.anchors * input_shape)
  267. # 2-w*h for large picture, use small scale, since small obj need more precise
  268. box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4]
  269. gt_shape = P.Shape()(gt_box)
  270. gt_box = P.Reshape()(gt_box, (gt_shape[0], 1, 1, 1, gt_shape[1], gt_shape[2]))
  271. # add one more dimension for broadcast
  272. iou = self.iou(P.ExpandDims()(pred_boxes, -2), gt_box)
  273. # gt_box is x,y,h,w after normalize
  274. # [batch, grid[0], grid[1], num_anchor, num_gt]
  275. best_iou = self.reduce_max(iou, -1)
  276. # [batch, grid[0], grid[1], num_anchor]
  277. # ignore_mask IOU too small
  278. ignore_mask = best_iou < self.ignore_threshold
  279. ignore_mask = P.Cast()(ignore_mask, ms.float32)
  280. ignore_mask = P.ExpandDims()(ignore_mask, -1)
  281. # ignore_mask backpro will cause a lot maximunGrad and minimumGrad time consume.
  282. # so we turn off its gradient
  283. ignore_mask = F.stop_gradient(ignore_mask)
  284. xy_loss = self.xy_loss(object_mask, box_loss_scale, prediction[:, :, :, :, :2], true_xy)
  285. wh_loss = self.wh_loss(object_mask, box_loss_scale, prediction[:, :, :, :, 2:4], true_wh)
  286. confidence_loss = self.confidenceLoss(object_mask, prediction[:, :, :, :, 4:5], ignore_mask)
  287. class_loss = self.classLoss(object_mask, prediction[:, :, :, :, 5:], class_probs)
  288. loss = xy_loss + wh_loss + confidence_loss + class_loss
  289. batch_size = P.Shape()(prediction)[0]
  290. return loss / batch_size
  291. class YOLOV3DarkNet53(nn.Cell):
  292. """
  293. Darknet based YOLOV3 network.
  294. Args:
  295. is_training: Bool. Whether train or not.
  296. Returns:
  297. Cell, cell instance of Darknet based YOLOV3 neural network.
  298. Examples:
  299. YOLOV3DarkNet53(True)
  300. """
  301. def __init__(self, is_training):
  302. super(YOLOV3DarkNet53, self).__init__()
  303. self.config = ConfigYOLOV3DarkNet53()
  304. # YOLOv3 network
  305. self.feature_map = YOLOv3(backbone=DarkNet(ResidualBlock, self.config.backbone_layers,
  306. self.config.backbone_input_shape,
  307. self.config.backbone_shape,
  308. detect=True),
  309. backbone_shape=self.config.backbone_shape,
  310. out_channel=self.config.out_channel)
  311. # prediction on the default anchor boxes
  312. self.detect_1 = DetectionBlock('l', is_training=is_training)
  313. self.detect_2 = DetectionBlock('m', is_training=is_training)
  314. self.detect_3 = DetectionBlock('s', is_training=is_training)
  315. def construct(self, x, input_shape):
  316. big_object_output, medium_object_output, small_object_output = self.feature_map(x)
  317. output_big = self.detect_1(big_object_output, input_shape)
  318. output_me = self.detect_2(medium_object_output, input_shape)
  319. output_small = self.detect_3(small_object_output, input_shape)
  320. # big is the final output which has smallest feature map
  321. return output_big, output_me, output_small
  322. class YoloWithLossCell(nn.Cell):
  323. """YOLOV3 loss."""
  324. def __init__(self, network):
  325. super(YoloWithLossCell, self).__init__()
  326. self.yolo_network = network
  327. self.config = ConfigYOLOV3DarkNet53()
  328. self.loss_big = YoloLossBlock('l', self.config)
  329. self.loss_me = YoloLossBlock('m', self.config)
  330. self.loss_small = YoloLossBlock('s', self.config)
  331. def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, input_shape):
  332. yolo_out = self.yolo_network(x, input_shape)
  333. loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape)
  334. loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape)
  335. loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape)
  336. return loss_l + loss_m + loss_s
  337. class TrainingWrapper(nn.Cell):
  338. """Training wrapper."""
  339. def __init__(self, network, optimizer, sens=1.0):
  340. super(TrainingWrapper, self).__init__(auto_prefix=False)
  341. self.network = network
  342. self.weights = optimizer.parameters
  343. self.optimizer = optimizer
  344. self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
  345. self.sens = sens
  346. self.reducer_flag = False
  347. self.grad_reducer = None
  348. self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
  349. if self.parallel_mode in [ms.ParallelMode.DATA_PARALLEL, ms.ParallelMode.HYBRID_PARALLEL]:
  350. self.reducer_flag = True
  351. if self.reducer_flag:
  352. mean = context.get_auto_parallel_context("mirror_mean")
  353. if auto_parallel_context().get_device_num_is_set():
  354. degree = context.get_auto_parallel_context("device_num")
  355. else:
  356. degree = get_group_size()
  357. self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
  358. def construct(self, *args):
  359. weights = self.weights
  360. loss = self.network(*args)
  361. sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
  362. grads = self.grad(self.network, weights)(*args, sens)
  363. if self.reducer_flag:
  364. grads = self.grad_reducer(grads)
  365. return F.depend(loss, self.optimizer(grads))