From: @shuzigood Reviewed-by: @wuxuejian,@linqingke Signed-off-by: @wuxuejiantags/v1.2.0-rc1
| @@ -24,11 +24,9 @@ import numpy as np | |||||
| from pycocotools.coco import COCO | from pycocotools.coco import COCO | ||||
| from pycocotools.cocoeval import COCOeval | from pycocotools.cocoeval import COCOeval | ||||
| from mindspore import Tensor | |||||
| from mindspore.context import ParallelMode | from mindspore.context import ParallelMode | ||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | from mindspore.train.serialization import load_checkpoint, load_param_into_net | ||||
| import mindspore as ms | |||||
| from src.yolo import YOLOV3DarkNet53 | from src.yolo import YOLOV3DarkNet53 | ||||
| from src.logger import get_logger | from src.logger import get_logger | ||||
| @@ -297,7 +295,6 @@ def test(): | |||||
| # init detection engine | # init detection engine | ||||
| detection = DetectionEngine(args) | detection = DetectionEngine(args) | ||||
| input_shape = Tensor(tuple(config.test_img_shape), ms.float32) | |||||
| args.logger.info('Start inference....') | args.logger.info('Start inference....') | ||||
| for i, data in enumerate(ds.create_dict_iterator(num_epochs=1)): | for i, data in enumerate(ds.create_dict_iterator(num_epochs=1)): | ||||
| image = data["image"] | image = data["image"] | ||||
| @@ -305,7 +302,7 @@ def test(): | |||||
| image_shape = data["image_shape"] | image_shape = data["image_shape"] | ||||
| image_id = data["img_id"] | image_id = data["img_id"] | ||||
| prediction = network(image, input_shape) | |||||
| prediction = network(image) | |||||
| output_big, output_me, output_small = prediction | output_big, output_me, output_small = prediction | ||||
| output_big = output_big.asnumpy() | output_big = output_big.asnumpy() | ||||
| output_me = output_me.asnumpy() | output_me = output_me.asnumpy() | ||||
| @@ -324,7 +321,7 @@ def test(): | |||||
| eval_result = detection.get_eval_result() | eval_result = detection.get_eval_result() | ||||
| cost_time = time.time() - start_time | cost_time = time.time() - start_time | ||||
| args.logger.info('\n=============coco eval reulst=========\n' + eval_result) | |||||
| args.logger.info('\n=============coco eval result=========\n' + eval_result) | |||||
| args.logger.info('testing cost time {:.2f}h'.format(cost_time / 3600.)) | args.logger.info('testing cost time {:.2f}h'.format(cost_time / 3600.)) | ||||
| @@ -47,6 +47,5 @@ if __name__ == "__main__": | |||||
| shape = [args.batch_size, 3] + config.test_img_shape | shape = [args.batch_size, 3] + config.test_img_shape | ||||
| input_data = Tensor(np.zeros(shape), ms.float32) | input_data = Tensor(np.zeros(shape), ms.float32) | ||||
| input_shape = Tensor(tuple(config.test_img_shape), ms.float32) | |||||
| export(network, input_data, input_shape, file_name=args.file_name, file_format=args.file_format) | |||||
| export(network, input_data, file_name=args.file_name, file_format=args.file_format) | |||||
| @@ -365,6 +365,7 @@ class YOLOV3DarkNet53(nn.Cell): | |||||
| def __init__(self, is_training): | def __init__(self, is_training): | ||||
| super(YOLOV3DarkNet53, self).__init__() | super(YOLOV3DarkNet53, self).__init__() | ||||
| self.config = ConfigYOLOV3DarkNet53() | self.config = ConfigYOLOV3DarkNet53() | ||||
| self.tenser_to_array = P.TupleToArray() | |||||
| # YOLOv3 network | # YOLOv3 network | ||||
| self.feature_map = YOLOv3(backbone=DarkNet(ResidualBlock, self.config.backbone_layers, | self.feature_map = YOLOv3(backbone=DarkNet(ResidualBlock, self.config.backbone_layers, | ||||
| @@ -379,7 +380,9 @@ class YOLOV3DarkNet53(nn.Cell): | |||||
| self.detect_2 = DetectionBlock('m', is_training=is_training) | self.detect_2 = DetectionBlock('m', is_training=is_training) | ||||
| self.detect_3 = DetectionBlock('s', is_training=is_training) | self.detect_3 = DetectionBlock('s', is_training=is_training) | ||||
| def construct(self, x, input_shape): | |||||
| def construct(self, x): | |||||
| input_shape = F.shape(x)[2:4] | |||||
| input_shape = F.cast(self.tenser_to_array(input_shape), ms.float32) | |||||
| big_object_output, medium_object_output, small_object_output = self.feature_map(x) | big_object_output, medium_object_output, small_object_output = self.feature_map(x) | ||||
| output_big = self.detect_1(big_object_output, input_shape) | output_big = self.detect_1(big_object_output, input_shape) | ||||
| output_me = self.detect_2(medium_object_output, input_shape) | output_me = self.detect_2(medium_object_output, input_shape) | ||||
| @@ -394,12 +397,15 @@ class YoloWithLossCell(nn.Cell): | |||||
| super(YoloWithLossCell, self).__init__() | super(YoloWithLossCell, self).__init__() | ||||
| self.yolo_network = network | self.yolo_network = network | ||||
| self.config = ConfigYOLOV3DarkNet53() | self.config = ConfigYOLOV3DarkNet53() | ||||
| self.tenser_to_array = P.TupleToArray() | |||||
| self.loss_big = YoloLossBlock('l', self.config) | self.loss_big = YoloLossBlock('l', self.config) | ||||
| self.loss_me = YoloLossBlock('m', self.config) | self.loss_me = YoloLossBlock('m', self.config) | ||||
| self.loss_small = YoloLossBlock('s', self.config) | self.loss_small = YoloLossBlock('s', self.config) | ||||
| def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2, input_shape): | |||||
| yolo_out = self.yolo_network(x, input_shape) | |||||
| def construct(self, x, y_true_0, y_true_1, y_true_2, gt_0, gt_1, gt_2): | |||||
| input_shape = F.shape(x)[2:4] | |||||
| input_shape = F.cast(self.tenser_to_array(input_shape), ms.float32) | |||||
| yolo_out = self.yolo_network(x) | |||||
| loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape) | loss_l = self.loss_big(*yolo_out[0], y_true_0, gt_0, input_shape) | ||||
| loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape) | loss_m = self.loss_me(*yolo_out[1], y_true_1, gt_1, input_shape) | ||||
| loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape) | loss_s = self.loss_small(*yolo_out[2], y_true_2, gt_2, input_shape) | ||||
| @@ -26,7 +26,6 @@ from mindspore import context | |||||
| from mindspore.communication.management import init, get_rank, get_group_size | from mindspore.communication.management import init, get_rank, get_group_size | ||||
| from mindspore.train.callback import ModelCheckpoint, RunContext | from mindspore.train.callback import ModelCheckpoint, RunContext | ||||
| from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig | from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig | ||||
| import mindspore as ms | |||||
| from mindspore import amp | from mindspore import amp | ||||
| from mindspore.train.loss_scale_manager import FixedLossScaleManager | from mindspore.train.loss_scale_manager import FixedLossScaleManager | ||||
| from mindspore.common import set_seed | from mindspore.common import set_seed | ||||
| @@ -254,9 +253,8 @@ def train(): | |||||
| batch_gt_box1 = Tensor.from_numpy(data['gt_box2']) | batch_gt_box1 = Tensor.from_numpy(data['gt_box2']) | ||||
| batch_gt_box2 = Tensor.from_numpy(data['gt_box3']) | batch_gt_box2 = Tensor.from_numpy(data['gt_box3']) | ||||
| input_shape = Tensor(tuple(input_shape[::-1]), ms.float32) | |||||
| loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, | loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, | ||||
| batch_gt_box2, input_shape) | |||||
| batch_gt_box2) | |||||
| loss_meter.update(loss.asnumpy()) | loss_meter.update(loss.asnumpy()) | ||||
| if args.rank_save_ckpt_flag: | if args.rank_save_ckpt_flag: | ||||