You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

yolo.py 22 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """YOLOv4 based on DarkNet."""
  16. import mindspore as ms
  17. import mindspore.nn as nn
  18. from mindspore.common.tensor import Tensor
  19. from mindspore import context
  20. from mindspore.context import ParallelMode
  21. from mindspore.parallel._auto_parallel_context import auto_parallel_context
  22. from mindspore.communication.management import get_group_size
  23. from mindspore.ops import operations as P
  24. from mindspore.ops import functional as F
  25. from mindspore.ops import composite as C
  26. from src.cspdarknet53 import CspDarkNet53, ResidualBlock
  27. from src.config import ConfigYOLOV4CspDarkNet53
  28. from src.loss import XYLoss, WHLoss, ConfidenceLoss, ClassLoss
  29. def _conv_bn_leakyrelu(in_channel,
  30. out_channel,
  31. ksize,
  32. stride=1,
  33. padding=0,
  34. dilation=1,
  35. alpha=0.1,
  36. momentum=0.9,
  37. eps=1e-5,
  38. pad_mode="same"):
  39. """Get a conv2d batchnorm and relu layer"""
  40. return nn.SequentialCell(
  41. [nn.Conv2d(in_channel,
  42. out_channel,
  43. kernel_size=ksize,
  44. stride=stride,
  45. padding=padding,
  46. dilation=dilation,
  47. pad_mode=pad_mode),
  48. nn.BatchNorm2d(out_channel, momentum=momentum, eps=eps),
  49. nn.LeakyReLU(alpha)]
  50. )
  51. class YoloBlock(nn.Cell):
  52. """
  53. YoloBlock for YOLOv4.
  54. Args:
  55. in_channels: Integer. Input channel.
  56. out_chls: Integer. Middle channel.
  57. out_channels: Integer. Output channel.
  58. Returns:
  59. Tuple, tuple of output tensor,(f1,f2,f3).
  60. Examples:
  61. YoloBlock(1024, 512, 255)
  62. """
  63. def __init__(self, in_channels, out_chls, out_channels):
  64. super(YoloBlock, self).__init__()
  65. out_chls_2 = out_chls*2
  66. self.conv0 = _conv_bn_leakyrelu(in_channels, out_chls, ksize=1)
  67. self.conv1 = _conv_bn_leakyrelu(out_chls, out_chls_2, ksize=3)
  68. self.conv2 = _conv_bn_leakyrelu(out_chls_2, out_chls, ksize=1)
  69. self.conv3 = _conv_bn_leakyrelu(out_chls, out_chls_2, ksize=3)
  70. self.conv4 = _conv_bn_leakyrelu(out_chls_2, out_chls, ksize=1)
  71. self.conv5 = _conv_bn_leakyrelu(out_chls, out_chls_2, ksize=3)
  72. self.conv6 = nn.Conv2d(out_chls_2, out_channels, kernel_size=1, stride=1, has_bias=True)
  73. def construct(self, x):
  74. """construct method"""
  75. c1 = self.conv0(x)
  76. c2 = self.conv1(c1)
  77. c3 = self.conv2(c2)
  78. c4 = self.conv3(c3)
  79. c5 = self.conv4(c4)
  80. c6 = self.conv5(c5)
  81. out = self.conv6(c6)
  82. return c5, out
  83. class YOLOv4(nn.Cell):
  84. """
  85. YOLOv4 Network.
  86. Note:
  87. backbone = CspDarkNet53
  88. Args:
  89. num_classes: Integer. Class number.
  90. feature_shape: List. Input image shape, [N,C,H,W].
  91. backbone_shape: List. Darknet output channels shape.
  92. backbone: Cell. Backbone Network.
  93. out_channel: Integer. Output channel.
  94. Returns:
  95. Tensor, output tensor.
  96. Examples:
  97. YOLOv4(feature_shape=[1,3,416,416],
  98. backbone_shape=[64, 128, 256, 512, 1024]
  99. backbone=CspDarkNet53(),
  100. out_channel=255)
  101. """
  102. def __init__(self, backbone_shape, backbone, out_channel):
  103. super(YOLOv4, self).__init__()
  104. self.out_channel = out_channel
  105. self.backbone = backbone
  106. self.conv1 = _conv_bn_leakyrelu(1024, 512, ksize=1)
  107. self.conv2 = _conv_bn_leakyrelu(512, 1024, ksize=3)
  108. self.conv3 = _conv_bn_leakyrelu(1024, 512, ksize=1)
  109. self.maxpool1 = nn.MaxPool2d(kernel_size=5, stride=1, pad_mode='same')
  110. self.maxpool2 = nn.MaxPool2d(kernel_size=9, stride=1, pad_mode='same')
  111. self.maxpool3 = nn.MaxPool2d(kernel_size=13, stride=1, pad_mode='same')
  112. self.conv4 = _conv_bn_leakyrelu(2048, 512, ksize=1)
  113. self.conv5 = _conv_bn_leakyrelu(512, 1024, ksize=3)
  114. self.conv6 = _conv_bn_leakyrelu(1024, 512, ksize=1)
  115. self.conv7 = _conv_bn_leakyrelu(512, 256, ksize=1)
  116. self.conv8 = _conv_bn_leakyrelu(512, 256, ksize=1)
  117. self.backblock0 = YoloBlock(backbone_shape[-2], out_chls=backbone_shape[-3], out_channels=out_channel)
  118. self.conv9 = _conv_bn_leakyrelu(256, 128, ksize=1)
  119. self.conv10 = _conv_bn_leakyrelu(256, 128, ksize=1)
  120. self.conv11 = _conv_bn_leakyrelu(128, 256, ksize=3, stride=2)
  121. self.conv12 = _conv_bn_leakyrelu(256, 512, ksize=3, stride=2)
  122. self.backblock1 = YoloBlock(backbone_shape[-3], out_chls=backbone_shape[-4], out_channels=out_channel)
  123. self.backblock2 = YoloBlock(backbone_shape[-2], out_chls=backbone_shape[-3], out_channels=out_channel)
  124. self.backblock3 = YoloBlock(backbone_shape[-1], out_chls=backbone_shape[-2], out_channels=out_channel)
  125. self.concat = P.Concat(axis=1)
  126. def construct(self, x):
  127. """
  128. input_shape of x is (batch_size, 3, h, w)
  129. feature_map1 is (batch_size, backbone_shape[2], h/8, w/8)
  130. feature_map2 is (batch_size, backbone_shape[3], h/16, w/16)
  131. feature_map3 is (batch_size, backbone_shape[4], h/32, w/32)
  132. """
  133. img_hight = P.Shape()(x)[2]
  134. img_width = P.Shape()(x)[3]
  135. # input=(1,3,608,608)
  136. # feature_map1=(1,256,76,76)
  137. # feature_map2=(1,512,38,38)
  138. # feature_map3=(1,1024,19,19)
  139. feature_map1, feature_map2, feature_map3 = self.backbone(x)
  140. con1 = self.conv1(feature_map3)
  141. con2 = self.conv2(con1)
  142. con3 = self.conv3(con2)
  143. m1 = self.maxpool1(con3)
  144. m2 = self.maxpool2(con3)
  145. m3 = self.maxpool3(con3)
  146. spp = self.concat((m3, m2, m1, con3))
  147. con4 = self.conv4(spp)
  148. con5 = self.conv5(con4)
  149. con6 = self.conv6(con5)
  150. con7 = self.conv7(con6)
  151. ups1 = P.ResizeNearestNeighbor((img_hight / 16, img_width / 16))(con7)
  152. con8 = self.conv8(feature_map2)
  153. con9 = self.concat((ups1, con8))
  154. con10, _ = self.backblock0(con9)
  155. con11 = self.conv9(con10)
  156. ups2 = P.ResizeNearestNeighbor((img_hight / 8, img_width / 8))(con11)
  157. con12 = self.conv10(feature_map1)
  158. con13 = self.concat((ups2, con12))
  159. con14, small_object_output = self.backblock1(con13)
  160. con15 = self.conv11(con14)
  161. con16 = self.concat((con15, con10))
  162. con17, medium_object_output = self.backblock2(con16)
  163. con18 = self.conv12(con17)
  164. con19 = self.concat((con18, con6))
  165. _, big_object_output = self.backblock3(con19)
  166. return big_object_output, medium_object_output, small_object_output
  167. class DetectionBlock(nn.Cell):
  168. """
  169. YOLOv4 detection Network. It will finally output the detection result.
  170. Args:
  171. scale: Character.
  172. config: ConfigYOLOV4CspDarkNet53, Configuration instance.
  173. is_training: Bool, Whether train or not, default True.
  174. Returns:
  175. Tuple, tuple of output tensor,(f1,f2,f3).
  176. Examples:
  177. DetectionBlock(scale='l',stride=32)
  178. """
  179. def __init__(self, scale, config=ConfigYOLOV4CspDarkNet53(), is_training=True):
  180. super(DetectionBlock, self).__init__()
  181. self.config = config
  182. if scale == 's':
  183. idx = (0, 1, 2)
  184. self.scale_x_y = 1.2
  185. self.offset_x_y = 0.1
  186. elif scale == 'm':
  187. idx = (3, 4, 5)
  188. self.scale_x_y = 1.1
  189. self.offset_x_y = 0.05
  190. elif scale == 'l':
  191. idx = (6, 7, 8)
  192. self.scale_x_y = 1.05
  193. self.offset_x_y = 0.025
  194. else:
  195. raise KeyError("Invalid scale value for DetectionBlock")
  196. self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
  197. self.num_anchors_per_scale = 3
  198. self.num_attrib = 4+1+self.config.num_classes
  199. self.lambda_coord = 1
  200. self.sigmoid = nn.Sigmoid()
  201. self.reshape = P.Reshape()
  202. self.tile = P.Tile()
  203. self.concat = P.Concat(axis=-1)
  204. self.conf_training = is_training
  205. def construct(self, x, input_shape):
  206. """construct method"""
  207. num_batch = P.Shape()(x)[0]
  208. grid_size = P.Shape()(x)[2:4]
  209. # Reshape and transpose the feature to [n, grid_size[0], grid_size[1], 3, num_attrib]
  210. prediction = P.Reshape()(x, (num_batch,
  211. self.num_anchors_per_scale,
  212. self.num_attrib,
  213. grid_size[0],
  214. grid_size[1]))
  215. prediction = P.Transpose()(prediction, (0, 3, 4, 1, 2))
  216. range_x = range(grid_size[1])
  217. range_y = range(grid_size[0])
  218. grid_x = P.Cast()(F.tuple_to_array(range_x), ms.float32)
  219. grid_y = P.Cast()(F.tuple_to_array(range_y), ms.float32)
  220. # Tensor of shape [grid_size[0], grid_size[1], 1, 1] representing the coordinate of x/y axis for each grid
  221. # [batch, gridx, gridy, 1, 1]
  222. grid_x = self.tile(self.reshape(grid_x, (1, 1, -1, 1, 1)), (1, grid_size[0], 1, 1, 1))
  223. grid_y = self.tile(self.reshape(grid_y, (1, -1, 1, 1, 1)), (1, 1, grid_size[1], 1, 1))
  224. # Shape is [grid_size[0], grid_size[1], 1, 2]
  225. grid = self.concat((grid_x, grid_y))
  226. box_xy = prediction[:, :, :, :, :2]
  227. box_wh = prediction[:, :, :, :, 2:4]
  228. box_confidence = prediction[:, :, :, :, 4:5]
  229. box_probs = prediction[:, :, :, :, 5:]
  230. # gridsize1 is x
  231. # gridsize0 is y
  232. box_xy = (self.scale_x_y * self.sigmoid(box_xy) - self.offset_x_y + grid) / \
  233. P.Cast()(F.tuple_to_array((grid_size[1], grid_size[0])), ms.float32)
  234. # box_wh is w->h
  235. box_wh = P.Exp()(box_wh) * self.anchors / input_shape
  236. box_confidence = self.sigmoid(box_confidence)
  237. box_probs = self.sigmoid(box_probs)
  238. if self.conf_training:
  239. return prediction, box_xy, box_wh
  240. return self.concat((box_xy, box_wh, box_confidence, box_probs))
  241. class Iou(nn.Cell):
  242. """Calculate the iou of boxes"""
  243. def __init__(self):
  244. super(Iou, self).__init__()
  245. self.min = P.Minimum()
  246. self.max = P.Maximum()
  247. def construct(self, box1, box2):
  248. """
  249. box1: pred_box [batch, gx, gy, anchors, 1, 4] ->4: [x_center, y_center, w, h]
  250. box2: gt_box [batch, 1, 1, 1, maxbox, 4]
  251. convert to topLeft and rightDown
  252. """
  253. box1_xy = box1[:, :, :, :, :, :2]
  254. box1_wh = box1[:, :, :, :, :, 2:4]
  255. box1_mins = box1_xy - box1_wh / F.scalar_to_array(2.0) # topLeft
  256. box1_maxs = box1_xy + box1_wh / F.scalar_to_array(2.0) # rightDown
  257. box2_xy = box2[:, :, :, :, :, :2]
  258. box2_wh = box2[:, :, :, :, :, 2:4]
  259. box2_mins = box2_xy - box2_wh / F.scalar_to_array(2.0)
  260. box2_maxs = box2_xy + box2_wh / F.scalar_to_array(2.0)
  261. intersect_mins = self.max(box1_mins, box2_mins)
  262. intersect_maxs = self.min(box1_maxs, box2_maxs)
  263. intersect_wh = self.max(intersect_maxs - intersect_mins, F.scalar_to_array(0.0))
  264. # P.squeeze: for effiecient slice
  265. intersect_area = P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 0:1]) * \
  266. P.Squeeze(-1)(intersect_wh[:, :, :, :, :, 1:2])
  267. box1_area = P.Squeeze(-1)(box1_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box1_wh[:, :, :, :, :, 1:2])
  268. box2_area = P.Squeeze(-1)(box2_wh[:, :, :, :, :, 0:1]) * P.Squeeze(-1)(box2_wh[:, :, :, :, :, 1:2])
  269. iou = intersect_area / (box1_area + box2_area - intersect_area)
  270. # iou : [batch, gx, gy, anchors, maxboxes]
  271. return iou
  272. class YoloLossBlock(nn.Cell):
  273. """
  274. Loss block cell of YOLOV4 network.
  275. """
  276. def __init__(self, scale, config=ConfigYOLOV4CspDarkNet53()):
  277. super(YoloLossBlock, self).__init__()
  278. self.config = config
  279. if scale == 's':
  280. # anchor mask
  281. idx = (0, 1, 2)
  282. elif scale == 'm':
  283. idx = (3, 4, 5)
  284. elif scale == 'l':
  285. idx = (6, 7, 8)
  286. else:
  287. raise KeyError("Invalid scale value for DetectionBlock")
  288. self.anchors = Tensor([self.config.anchor_scales[i] for i in idx], ms.float32)
  289. self.ignore_threshold = Tensor(self.config.ignore_threshold, ms.float32)
  290. self.concat = P.Concat(axis=-1)
  291. self.iou = Iou()
  292. self.reduce_max = P.ReduceMax(keep_dims=False)
  293. self.xy_loss = XYLoss()
  294. self.wh_loss = WHLoss()
  295. self.confidence_loss = ConfidenceLoss()
  296. self.class_loss = ClassLoss()
  297. self.reduce_sum = P.ReduceSum()
  298. self.giou = Giou()
  299. def construct(self, prediction, pred_xy, pred_wh, y_true, gt_box, input_shape):
  300. """
  301. prediction : origin output from yolo
  302. pred_xy: (sigmoid(xy)+grid)/grid_size
  303. pred_wh: (exp(wh)*anchors)/input_shape
  304. y_true : after normalize
  305. gt_box: [batch, maxboxes, xyhw] after normalize
  306. """
  307. object_mask = y_true[:, :, :, :, 4:5]
  308. class_probs = y_true[:, :, :, :, 5:]
  309. true_boxes = y_true[:, :, :, :, :4]
  310. grid_shape = P.Shape()(prediction)[1:3]
  311. grid_shape = P.Cast()(F.tuple_to_array(grid_shape[::-1]), ms.float32)
  312. pred_boxes = self.concat((pred_xy, pred_wh))
  313. true_wh = y_true[:, :, :, :, 2:4]
  314. true_wh = P.Select()(P.Equal()(true_wh, 0.0),
  315. P.Fill()(P.DType()(true_wh),
  316. P.Shape()(true_wh), 1.0),
  317. true_wh)
  318. true_wh = P.Log()(true_wh / self.anchors * input_shape)
  319. # 2-w*h for large picture, use small scale, since small obj need more precise
  320. box_loss_scale = 2 - y_true[:, :, :, :, 2:3] * y_true[:, :, :, :, 3:4]
  321. gt_shape = P.Shape()(gt_box)
  322. gt_box = P.Reshape()(gt_box, (gt_shape[0], 1, 1, 1, gt_shape[1], gt_shape[2]))
  323. # add one more dimension for broadcast
  324. iou = self.iou(P.ExpandDims()(pred_boxes, -2), gt_box)
  325. # gt_box is x,y,h,w after normalize
  326. # [batch, grid[0], grid[1], num_anchor, num_gt]
  327. best_iou = self.reduce_max(iou, -1)
  328. # [batch, grid[0], grid[1], num_anchor]
  329. # ignore_mask IOU too small
  330. ignore_mask = best_iou < self.ignore_threshold
  331. ignore_mask = P.Cast()(ignore_mask, ms.float32)
  332. ignore_mask = P.ExpandDims()(ignore_mask, -1)
  333. # ignore_mask backpro will cause a lot maximunGrad and minimumGrad time consume.
  334. # so we turn off its gradient
  335. ignore_mask = F.stop_gradient(ignore_mask)
  336. confidence_loss = self.confidence_loss(object_mask, prediction[:, :, :, :, 4:5], ignore_mask)
  337. class_loss = self.class_loss(object_mask, prediction[:, :, :, :, 5:], class_probs)
  338. object_mask_me = P.Reshape()(object_mask, (-1, 1)) # [8, 72, 72, 3, 1]
  339. box_loss_scale_me = P.Reshape()(box_loss_scale, (-1, 1))
  340. pred_boxes_me = xywh2x1y1x2y2(pred_boxes)
  341. pred_boxes_me = P.Reshape()(pred_boxes_me, (-1, 4))
  342. true_boxes_me = xywh2x1y1x2y2(true_boxes)
  343. true_boxes_me = P.Reshape()(true_boxes_me, (-1, 4))
  344. ciou = self.giou(pred_boxes_me, true_boxes_me)
  345. ciou_loss = object_mask_me * box_loss_scale_me * (1 - ciou)
  346. ciou_loss_me = self.reduce_sum(ciou_loss, ())
  347. loss = ciou_loss_me * 10 + confidence_loss + class_loss
  348. batch_size = P.Shape()(prediction)[0]
  349. return loss / batch_size
  350. class YOLOV4CspDarkNet53(nn.Cell):
  351. """
  352. Darknet based YOLOV4 network.
  353. Args:
  354. is_training: Bool. Whether train or not.
  355. Returns:
  356. Cell, cell instance of Darknet based YOLOV4 neural network.
  357. Examples:
  358. YOLOV4CspDarkNet53(True)
  359. """
  360. def __init__(self, is_training):
  361. super(YOLOV4CspDarkNet53, self).__init__()
  362. self.config = ConfigYOLOV4CspDarkNet53()
  363. # YOLOv4 network
  364. self.feature_map = YOLOv4(backbone=CspDarkNet53(ResidualBlock, detect=True),
  365. backbone_shape=self.config.backbone_shape,
  366. out_channel=self.config.out_channel)
  367. # prediction on the default anchor boxes
  368. self.detect_1 = DetectionBlock('l', is_training=is_training)
  369. self.detect_2 = DetectionBlock('m', is_training=is_training)
  370. self.detect_3 = DetectionBlock('s', is_training=is_training)
  371. def construct(self, x, input_shape):
  372. big_object_output, medium_object_output, small_object_output = self.feature_map(x)
  373. output_big = self.detect_1(big_object_output, input_shape)
  374. output_me = self.detect_2(medium_object_output, input_shape)
  375. output_small = self.detect_3(small_object_output, input_shape)
  376. # big is the final output which has smallest feature map
  377. return output_big, output_me, output_small
  378. class YoloWithLossCell(nn.Cell):
  379. """YOLOV4 loss."""
  380. def __init__(self, network):
  381. super(YoloWithLossCell, self).__init__()
  382. self.yolo_network = network
  383. self.config = ConfigYOLOV4CspDarkNet53()
  384. self.loss_big = YoloLossBlock('l', self.config)
  385. self.loss_me = YoloLossBlock('m', self.config)
  386. self.loss_small = YoloLossBlock('s', self.config)
  387. def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, input_shape):
  388. yolo_out = self.yolo_network(x, input_shape)
  389. loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape)
  390. loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape)
  391. loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape)
  392. return loss_l + loss_m + loss_s
  393. class TrainingWrapper(nn.Cell):
  394. """Training wrapper."""
  395. def __init__(self, network, optimizer, sens=1.0):
  396. super(TrainingWrapper, self).__init__(auto_prefix=False)
  397. self.network = network
  398. self.network.set_grad()
  399. self.weights = optimizer.parameters
  400. self.optimizer = optimizer
  401. self.grad = C.GradOperation(get_by_list=True, sens_param=True)
  402. self.sens = sens
  403. self.reducer_flag = False
  404. self.grad_reducer = None
  405. self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
  406. if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
  407. self.reducer_flag = True
  408. if self.reducer_flag:
  409. mean = context.get_auto_parallel_context("gradients_mean")
  410. if auto_parallel_context().get_device_num_is_set():
  411. degree = context.get_auto_parallel_context("device_num")
  412. else:
  413. degree = get_group_size()
  414. self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
  415. def construct(self, *args):
  416. weights = self.weights
  417. loss = self.network(*args)
  418. sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
  419. grads = self.grad(self.network, weights)(*args, sens)
  420. if self.reducer_flag:
  421. grads = self.grad_reducer(grads)
  422. return F.depend(loss, self.optimizer(grads))
  423. class Giou(nn.Cell):
  424. """Calculating giou"""
  425. def __init__(self):
  426. super(Giou, self).__init__()
  427. self.cast = P.Cast()
  428. self.reshape = P.Reshape()
  429. self.min = P.Minimum()
  430. self.max = P.Maximum()
  431. self.concat = P.Concat(axis=1)
  432. self.mean = P.ReduceMean()
  433. self.div = P.RealDiv()
  434. self.eps = 0.000001
  435. def construct(self, box_p, box_gt):
  436. """construct method"""
  437. box_p_area = (box_p[..., 2:3] - box_p[..., 0:1]) * (box_p[..., 3:4] - box_p[..., 1:2])
  438. box_gt_area = (box_gt[..., 2:3] - box_gt[..., 0:1]) * (box_gt[..., 3:4] - box_gt[..., 1:2])
  439. x_1 = self.max(box_p[..., 0:1], box_gt[..., 0:1])
  440. x_2 = self.min(box_p[..., 2:3], box_gt[..., 2:3])
  441. y_1 = self.max(box_p[..., 1:2], box_gt[..., 1:2])
  442. y_2 = self.min(box_p[..., 3:4], box_gt[..., 3:4])
  443. intersection = (y_2 - y_1) * (x_2 - x_1)
  444. xc_1 = self.min(box_p[..., 0:1], box_gt[..., 0:1])
  445. xc_2 = self.max(box_p[..., 2:3], box_gt[..., 2:3])
  446. yc_1 = self.min(box_p[..., 1:2], box_gt[..., 1:2])
  447. yc_2 = self.max(box_p[..., 3:4], box_gt[..., 3:4])
  448. c_area = (xc_2 - xc_1) * (yc_2 - yc_1)
  449. union = box_p_area + box_gt_area - intersection
  450. union = union + self.eps
  451. c_area = c_area + self.eps
  452. iou = self.div(self.cast(intersection, ms.float32), self.cast(union, ms.float32))
  453. res_mid0 = c_area - union
  454. res_mid1 = self.div(self.cast(res_mid0, ms.float32), self.cast(c_area, ms.float32))
  455. giou = iou - res_mid1
  456. giou = C.clip_by_value(giou, -1.0, 1.0)
  457. return giou
  458. def xywh2x1y1x2y2(box_xywh):
  459. boxes_x1 = box_xywh[..., 0:1] - box_xywh[..., 2:3] / 2
  460. boxes_y1 = box_xywh[..., 1:2] - box_xywh[..., 3:4] / 2
  461. boxes_x2 = box_xywh[..., 0:1] + box_xywh[..., 2:3] / 2
  462. boxes_y2 = box_xywh[..., 1:2] + box_xywh[..., 3:4] / 2
  463. boxes_x1y1x2y2 = P.Concat(-1)((boxes_x1, boxes_y1, boxes_x2, boxes_y2))
  464. return boxes_x1y1x2y2