Browse Source

!10978 modified two parameters to one parameter in yolov3_darknet53 network

From: @shuzigood
Reviewed-by: @wuxuejian,@linqingke
Signed-off-by: @wuxuejian
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
cfda3a49b6
4 changed files with 13 additions and 13 deletions
  1. +2
    -5
      model_zoo/official/cv/yolov3_darknet53/eval.py
  2. +1
    -2
      model_zoo/official/cv/yolov3_darknet53/export.py
  3. +9
    -3
      model_zoo/official/cv/yolov3_darknet53/src/yolo.py
  4. +1
    -3
      model_zoo/official/cv/yolov3_darknet53/train.py

+ 2
- 5
model_zoo/official/cv/yolov3_darknet53/eval.py View File

@@ -24,11 +24,9 @@ import numpy as np
from pycocotools.coco import COCO from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval from pycocotools.cocoeval import COCOeval


from mindspore import Tensor
from mindspore.context import ParallelMode from mindspore.context import ParallelMode
from mindspore import context from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore as ms


from src.yolo import YOLOV3DarkNet53 from src.yolo import YOLOV3DarkNet53
from src.logger import get_logger from src.logger import get_logger
@@ -297,7 +295,6 @@ def test():
# init detection engine # init detection engine
detection = DetectionEngine(args) detection = DetectionEngine(args)


input_shape = Tensor(tuple(config.test_img_shape), ms.float32)
args.logger.info('Start inference....') args.logger.info('Start inference....')
for i, data in enumerate(ds.create_dict_iterator(num_epochs=1)): for i, data in enumerate(ds.create_dict_iterator(num_epochs=1)):
image = data["image"] image = data["image"]
@@ -305,7 +302,7 @@ def test():
image_shape = data["image_shape"] image_shape = data["image_shape"]
image_id = data["img_id"] image_id = data["img_id"]


prediction = network(image, input_shape)
prediction = network(image)
output_big, output_me, output_small = prediction output_big, output_me, output_small = prediction
output_big = output_big.asnumpy() output_big = output_big.asnumpy()
output_me = output_me.asnumpy() output_me = output_me.asnumpy()
@@ -324,7 +321,7 @@ def test():
eval_result = detection.get_eval_result() eval_result = detection.get_eval_result()


cost_time = time.time() - start_time cost_time = time.time() - start_time
args.logger.info('\n=============coco eval reulst=========\n' + eval_result)
args.logger.info('\n=============coco eval result=========\n' + eval_result)
args.logger.info('testing cost time {:.2f}h'.format(cost_time / 3600.)) args.logger.info('testing cost time {:.2f}h'.format(cost_time / 3600.))






+ 1
- 2
model_zoo/official/cv/yolov3_darknet53/export.py View File

@@ -47,6 +47,5 @@ if __name__ == "__main__":


shape = [args.batch_size, 3] + config.test_img_shape shape = [args.batch_size, 3] + config.test_img_shape
input_data = Tensor(np.zeros(shape), ms.float32) input_data = Tensor(np.zeros(shape), ms.float32)
input_shape = Tensor(tuple(config.test_img_shape), ms.float32)


export(network, input_data, input_shape, file_name=args.file_name, file_format=args.file_format)
export(network, input_data, file_name=args.file_name, file_format=args.file_format)

+ 9
- 3
model_zoo/official/cv/yolov3_darknet53/src/yolo.py View File

@@ -365,6 +365,7 @@ class YOLOV3DarkNet53(nn.Cell):
def __init__(self, is_training): def __init__(self, is_training):
super(YOLOV3DarkNet53, self).__init__() super(YOLOV3DarkNet53, self).__init__()
self.config = ConfigYOLOV3DarkNet53() self.config = ConfigYOLOV3DarkNet53()
self.tenser_to_array = P.TupleToArray()
# YOLOv3 network # YOLOv3 network
self.feature_map = YOLOv3(backbone=DarkNet(ResidualBlock, self.config.backbone_layers, self.feature_map = YOLOv3(backbone=DarkNet(ResidualBlock, self.config.backbone_layers,
@@ -379,7 +380,9 @@ class YOLOV3DarkNet53(nn.Cell):
self.detect_2 = DetectionBlock('m', is_training=is_training) self.detect_2 = DetectionBlock('m', is_training=is_training)
self.detect_3 = DetectionBlock('s', is_training=is_training) self.detect_3 = DetectionBlock('s', is_training=is_training)
def construct(self, x, input_shape):
def construct(self, x):
input_shape = F.shape(x)[2:4]
input_shape = F.cast(self.tenser_to_array(input_shape), ms.float32)
big_object_output, medium_object_output, small_object_output = self.feature_map(x) big_object_output, medium_object_output, small_object_output = self.feature_map(x)
output_big = self.detect_1(big_object_output, input_shape) output_big = self.detect_1(big_object_output, input_shape)
output_me = self.detect_2(medium_object_output, input_shape) output_me = self.detect_2(medium_object_output, input_shape)
@@ -394,12 +397,15 @@ class YoloWithLossCell(nn.Cell):
super(YoloWithLossCell, self).__init__() super(YoloWithLossCell, self).__init__()
self.yolo_network = network self.yolo_network = network
self.config = ConfigYOLOV3DarkNet53() self.config = ConfigYOLOV3DarkNet53()
self.tenser_to_array = P.TupleToArray()
self.loss_big = YoloLossBlock('l', self.config) self.loss_big = YoloLossBlock('l', self.config)
self.loss_me = YoloLossBlock('m', self.config) self.loss_me = YoloLossBlock('m', self.config)
self.loss_small = YoloLossBlock('s', self.config) self.loss_small = YoloLossBlock('s', self.config)
def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, input_shape):
yolo_out = self.yolo_network(x, input_shape)
def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2):
input_shape = F.shape(x)[2:4]
input_shape = F.cast(self.tenser_to_array(input_shape), ms.float32)
yolo_out = self.yolo_network(x)
loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape) loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape)
loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape) loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape)
loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape) loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape)


+ 1
- 3
model_zoo/official/cv/yolov3_darknet53/train.py View File

@@ -26,7 +26,6 @@ from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, RunContext from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
import mindspore as ms
from mindspore import amp from mindspore import amp
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import set_seed from mindspore.common import set_seed
@@ -254,9 +253,8 @@ def train():
batch_gt_box1 = Tensor.from_numpy(data['gt_box2']) batch_gt_box1 = Tensor.from_numpy(data['gt_box2'])
batch_gt_box2 = Tensor.from_numpy(data['gt_box3']) batch_gt_box2 = Tensor.from_numpy(data['gt_box3'])


input_shape = Tensor(tuple(input_shape[::-1]), ms.float32)
loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,
batch_gt_box2, input_shape)
batch_gt_box2)
loss_meter.update(loss.asnumpy()) loss_meter.update(loss.asnumpy())


if args.rank_save_ckpt_flag: if args.rank_save_ckpt_flag:


Loading…
Cancel
Save