Browse Source

!6257 modify yolov3-resnet18 test case and improve yolov3-darknet-quant performance

Merge pull request !6257 from chengxb7532/master
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
59a63d2566
3 changed files with 12 additions and 9 deletions
  1. +3
    -0
      model_zoo/official/cv/yolov3_darknet53_quant/src/yolo_dataset.py
  2. +7
    -7
      model_zoo/official/cv/yolov3_darknet53_quant/train.py
  3. +2
    -2
      tests/st/model_zoo_tests/yolov3/test_yolov3.py

+ 3
- 0
model_zoo/official/cv/yolov3_darknet53_quant/src/yolo_dataset.py View File

@@ -15,6 +15,7 @@
"""YOLOV3 dataset.""" """YOLOV3 dataset."""
import os import os
import cv2
from PIL import Image from PIL import Image
from pycocotools.coco import COCO from pycocotools.coco import COCO
import mindspore.dataset as de import mindspore.dataset as de
@@ -142,6 +143,8 @@ class COCOYoloDataset:
def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank, def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, rank,
config=None, is_training=True, shuffle=True): config=None, is_training=True, shuffle=True):
"""Create dataset for YOLOV3.""" """Create dataset for YOLOV3."""
cv2.setNumThreads(0)
if is_training: if is_training:
filter_crowd = True filter_crowd = True
remove_empty_anno = True remove_empty_anno = True


+ 7
- 7
model_zoo/official/cv/yolov3_darknet53_quant/train.py View File

@@ -313,7 +313,7 @@ def train():
args.logger.info('iter[{}], shape{}'.format(i, input_shape[0])) args.logger.info('iter[{}], shape{}'.format(i, input_shape[0]))
shape_record.set(input_shape) shape_record.set(input_shape)


images = Tensor(images)
images = Tensor.from_numpy(images)
annos = data["annotation"] annos = data["annotation"]
if args.group_size == 1: if args.group_size == 1:
batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \ batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
@@ -322,12 +322,12 @@ def train():
batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \ batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \
batch_preprocess_true_box_single(annos, config, input_shape) batch_preprocess_true_box_single(annos, config, input_shape)


batch_y_true_0 = Tensor(batch_y_true_0)
batch_y_true_1 = Tensor(batch_y_true_1)
batch_y_true_2 = Tensor(batch_y_true_2)
batch_gt_box0 = Tensor(batch_gt_box0)
batch_gt_box1 = Tensor(batch_gt_box1)
batch_gt_box2 = Tensor(batch_gt_box2)
batch_y_true_0 = Tensor.from_numpy(batch_y_true_0)
batch_y_true_1 = Tensor.from_numpy(batch_y_true_1)
batch_y_true_2 = Tensor.from_numpy(batch_y_true_2)
batch_gt_box0 = Tensor.from_numpy(batch_gt_box0)
batch_gt_box1 = Tensor.from_numpy(batch_gt_box1)
batch_gt_box2 = Tensor.from_numpy(batch_gt_box2)


input_shape = Tensor(tuple(input_shape[::-1]), ms.float32) input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,


+ 2
- 2
tests/st/model_zoo_tests/yolov3/test_yolov3.py View File

@@ -146,12 +146,12 @@ def test_yolov3():
assert loss_value[2] < expect_loss_value[2] assert loss_value[2] < expect_loss_value[2]


epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2] epoch_mseconds = np.array(time_monitor_callback.epoch_mseconds_list)[2]
expect_epoch_mseconds = 950
expect_epoch_mseconds = 1250
print("epoch mseconds: {}".format(epoch_mseconds)) print("epoch mseconds: {}".format(epoch_mseconds))
assert epoch_mseconds <= expect_epoch_mseconds assert epoch_mseconds <= expect_epoch_mseconds


per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)[2] per_step_mseconds = np.array(time_monitor_callback.per_step_mseconds_list)[2]
expect_per_step_mseconds = 110
expect_per_step_mseconds = 120
print("per step mseconds: {}".format(per_step_mseconds)) print("per step mseconds: {}".format(per_step_mseconds))
assert per_step_mseconds <= expect_per_step_mseconds assert per_step_mseconds <= expect_per_step_mseconds
print("yolov3 test case passed.") print("yolov3 test case passed.")

Loading…
Cancel
Save