diff --git a/model_zoo/official/cv/yolov3_darknet53/eval.py b/model_zoo/official/cv/yolov3_darknet53/eval.py index 8658ed8926..b270a39a26 100644 --- a/model_zoo/official/cv/yolov3_darknet53/eval.py +++ b/model_zoo/official/cv/yolov3_darknet53/eval.py @@ -24,11 +24,9 @@ import numpy as np from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval -from mindspore import Tensor from mindspore.context import ParallelMode from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net -import mindspore as ms from src.yolo import YOLOV3DarkNet53 from src.logger import get_logger @@ -297,7 +295,6 @@ def test(): # init detection engine detection = DetectionEngine(args) - input_shape = Tensor(tuple(config.test_img_shape), ms.float32) args.logger.info('Start inference....') for i, data in enumerate(ds.create_dict_iterator(num_epochs=1)): image = data["image"] @@ -305,7 +302,7 @@ def test(): image_shape = data["image_shape"] image_id = data["img_id"] - prediction = network(image, input_shape) + prediction = network(image) output_big, output_me, output_small = prediction output_big = output_big.asnumpy() output_me = output_me.asnumpy() @@ -324,7 +321,7 @@ def test(): eval_result = detection.get_eval_result() cost_time = time.time() - start_time - args.logger.info('\n=============coco eval reulst=========\n' + eval_result) + args.logger.info('\n=============coco eval result=========\n' + eval_result) args.logger.info('testing cost time {:.2f}h'.format(cost_time / 3600.)) diff --git a/model_zoo/official/cv/yolov3_darknet53/export.py b/model_zoo/official/cv/yolov3_darknet53/export.py index 8888ea3f3d..7f2c3ce696 100644 --- a/model_zoo/official/cv/yolov3_darknet53/export.py +++ b/model_zoo/official/cv/yolov3_darknet53/export.py @@ -47,6 +47,5 @@ if __name__ == "__main__": shape = [args.batch_size, 3] + config.test_img_shape input_data = Tensor(np.zeros(shape), ms.float32) - input_shape = Tensor(tuple(config.test_img_shape), ms.float32) - export(network, input_data, input_shape, file_name=args.file_name, file_format=args.file_format) + export(network, input_data, file_name=args.file_name, file_format=args.file_format) diff --git a/model_zoo/official/cv/yolov3_darknet53/src/yolo.py b/model_zoo/official/cv/yolov3_darknet53/src/yolo.py index c3e43f8e14..6d6e37d909 100644 --- a/model_zoo/official/cv/yolov3_darknet53/src/yolo.py +++ b/model_zoo/official/cv/yolov3_darknet53/src/yolo.py @@ -365,6 +365,7 @@ class YOLOV3DarkNet53(nn.Cell): def __init__(self, is_training): super(YOLOV3DarkNet53, self).__init__() self.config = ConfigYOLOV3DarkNet53() + self.tenser_to_array = P.TupleToArray() # YOLOv3 network self.feature_map = YOLOv3(backbone=DarkNet(ResidualBlock, self.config.backbone_layers, @@ -379,7 +380,9 @@ class YOLOV3DarkNet53(nn.Cell): self.detect_2 = DetectionBlock('m', is_training=is_training) self.detect_3 = DetectionBlock('s', is_training=is_training) - def construct(self, x, input_shape): + def construct(self, x): + input_shape = F.shape(x)[2:4] + input_shape = F.cast(self.tenser_to_array(input_shape), ms.float32) big_object_output, medium_object_output, small_object_output = self.feature_map(x) output_big = self.detect_1(big_object_output, input_shape) output_me = self.detect_2(medium_object_output, input_shape) @@ -394,12 +397,15 @@ class YoloWithLossCell(nn.Cell): super(YoloWithLossCell, self).__init__() self.yolo_network = network self.config = ConfigYOLOV3DarkNet53() + self.tenser_to_array = P.TupleToArray() self.loss_big = YoloLossBlock('l', self.config) self.loss_me = YoloLossBlock('m', self.config) self.loss_small = YoloLossBlock('s', self.config) - def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, input_shape): - yolo_out = self.yolo_network(x, input_shape) + def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2): + input_shape = F.shape(x)[2:4] + input_shape = F.cast(self.tenser_to_array(input_shape), ms.float32) + yolo_out = self.yolo_network(x) loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape) loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape) loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape) diff --git a/model_zoo/official/cv/yolov3_darknet53/train.py b/model_zoo/official/cv/yolov3_darknet53/train.py index 679f84da30..682e0eeffc 100644 --- a/model_zoo/official/cv/yolov3_darknet53/train.py +++ b/model_zoo/official/cv/yolov3_darknet53/train.py @@ -26,7 +26,6 @@ from mindspore import context from mindspore.communication.management import init, get_rank, get_group_size from mindspore.train.callback import ModelCheckpoint, RunContext from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig -import mindspore as ms from mindspore import amp from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.common import set_seed @@ -254,9 +253,8 @@ def train(): batch_gt_box1 = Tensor.from_numpy(data['gt_box2']) batch_gt_box2 = Tensor.from_numpy(data['gt_box3']) - input_shape = Tensor(tuple(input_shape[::-1]), ms.float32) loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, - batch_gt_box2, input_shape) + batch_gt_box2) loss_meter.update(loss.asnumpy()) if args.rank_save_ckpt_flag: