| @@ -137,7 +137,7 @@ def ssd_bboxes_encode(boxes): | |||||
| num_match_num = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32) | num_match_num = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32) | ||||
| return bboxes, t_label.astype(np.int32), num_match_num | return bboxes, t_label.astype(np.int32), num_match_num | ||||
| def ssd_bboxes_decode(boxes, index, image_shape): | |||||
| def ssd_bboxes_decode(boxes, index): | |||||
| """Decode predict boxes to [x, y, w, h]""" | """Decode predict boxes to [x, y, w, h]""" | ||||
| boxes_t = boxes[index] | boxes_t = boxes[index] | ||||
| default_boxes_t = default_boxes[index] | default_boxes_t = default_boxes[index] | ||||
| @@ -110,14 +110,12 @@ def metrics(pred_data): | |||||
| pred_boxes = sample['boxes'] | pred_boxes = sample['boxes'] | ||||
| boxes_scores = sample['box_scores'] | boxes_scores = sample['box_scores'] | ||||
| annotation = sample['annotation'] | annotation = sample['annotation'] | ||||
| image_shape = sample['image_shape'] | |||||
| annotation = np.squeeze(annotation, axis=0) | annotation = np.squeeze(annotation, axis=0) | ||||
| image_shape = np.squeeze(image_shape, axis=0) | |||||
| pred_labels = np.argmax(boxes_scores, axis=-1) | pred_labels = np.argmax(boxes_scores, axis=-1) | ||||
| index = np.nonzero(pred_labels) | index = np.nonzero(pred_labels) | ||||
| pred_boxes = ssd_bboxes_decode(pred_boxes, index, image_shape) | |||||
| pred_boxes = ssd_bboxes_decode(pred_boxes, index) | |||||
| pred_boxes = pred_boxes.clip(0, 1) | pred_boxes = pred_boxes.clip(0, 1) | ||||
| boxes_scores = np.max(boxes_scores, axis=-1) | boxes_scores = np.max(boxes_scores, axis=-1) | ||||
| @@ -60,7 +60,7 @@ def init_net_param(net, init='ones'): | |||||
| p.set_parameter_data(initializer(init, p.data.shape(), p.data.dtype())) | p.set_parameter_data(initializer(init, p.data.shape(), p.data.dtype())) | ||||
| if __name__ == '__main__': | |||||
| def main(): | |||||
| parser = argparse.ArgumentParser(description="YOLOv3 train") | parser = argparse.ArgumentParser(description="YOLOv3 train") | ||||
| parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create " | parser.add_argument("--only_create_dataset", type=bool, default=False, help="If set it true, only create " | ||||
| "Mindrecord, default is false.") | "Mindrecord, default is false.") | ||||
| @@ -153,3 +153,6 @@ if __name__ == '__main__': | |||||
| dataset_sink_mode = True | dataset_sink_mode = True | ||||
| print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.") | print("Start train YOLOv3, the first epoch will be slower because of the graph compilation.") | ||||
| model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode) | model.train(args_opt.epoch_size, dataset, callbacks=callback, dataset_sink_mode=dataset_sink_mode) | ||||
| if __name__ == '__main__': | |||||
| main() | |||||