
test_onnx.py

import argparse
import os.path as osp
import warnings
import numbers
from functools import partial
import sys

sys.path.append("/home/shanwei-luo/userdata/mmdetection")  # use the local mmdetection checkout

import numpy as np
import torch
import cv2
import onnx
import onnxruntime as ort
import mmcv
from mmcv import Config, DictAction
from mmdet.core.export.model_wrappers import ONNXRuntimeDetector
from mmdet.apis import (async_inference_detector, inference_detector,
                        init_detector, show_result_pyplot)

def impad(img,
          *,
          shape=None,
          padding=None,
          pad_val=0,
          padding_mode='constant'):
    """Pad an image to a target shape or by explicit padding
    (mirrors mmcv's impad)."""
    # Exactly one of `shape` and `padding` must be given.
    assert (shape is not None) ^ (padding is not None)
    if shape is not None:
        # (left, top, right, bottom) padding needed to reach the target shape.
        padding = (0, 0, shape[1] - img.shape[1], shape[0] - img.shape[0])

    # check pad_val
    if isinstance(pad_val, tuple):
        assert len(pad_val) == img.shape[-1]
    elif not isinstance(pad_val, numbers.Number):
        raise TypeError('pad_val must be an int or a tuple. '
                        f'But received {type(pad_val)}')

    # check padding
    if isinstance(padding, tuple) and len(padding) in [2, 4]:
        if len(padding) == 2:
            padding = (padding[0], padding[1], padding[0], padding[1])
    elif isinstance(padding, numbers.Number):
        padding = (padding, padding, padding, padding)
    else:
        raise ValueError('Padding must be an int or a 2- or 4-element tuple. '
                         f'But received {padding}')

    # check padding mode
    assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
    border_type = {
        'constant': cv2.BORDER_CONSTANT,
        'edge': cv2.BORDER_REPLICATE,
        'reflect': cv2.BORDER_REFLECT_101,
        'symmetric': cv2.BORDER_REFLECT
    }
    img = cv2.copyMakeBorder(
        img,
        padding[1],  # top
        padding[3],  # bottom
        padding[0],  # left
        padding[2],  # right
        border_type[padding_mode],
        value=pad_val)
    return img
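
# Illustrative sanity check for impad (not part of the original script):
# padding a 2x3 image to shape (4, 5) adds zeros on the right and bottom only.
# >>> impad(np.ones((2, 3, 3), np.float32), shape=(4, 5)).shape
# (4, 5, 3)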

def imnormalize(img, mean, std, to_rgb=True):
    """Normalize an image in place: optional BGR->RGB, then (img - mean) / std
    (mirrors mmcv's imnormalize)."""
    mean = np.float64(mean.reshape(1, -1))
    stdinv = 1 / np.float64(std.reshape(1, -1))
    if to_rgb:
        cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)  # inplace
    cv2.subtract(img, mean, img)  # inplace
    cv2.multiply(img, stdinv, img)  # inplace
    return img
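
# Illustrative check (not in the original script): a BGR pixel equal to the
# dataset mean maps to ~0 after the BGR->RGB swap and normalization.
# >>> m = np.array([123.675, 116.28, 103.53], np.float32)  # RGB mean
# >>> s = np.array([58.395, 57.12, 57.375], np.float32)
# >>> imnormalize(np.array([[[103.53, 116.28, 123.675]]], np.float32), m, s)
# ≈ array([[[0., 0., 0.]]], dtype=float32)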

def preprocess_example_input(input_config):
    """Load and preprocess one image the way the detector expects:
    keep-ratio resize, normalize, pad to a multiple of 32, HWC -> CHW."""
    input_path = input_config['input_path']
    image = cv2.imread(input_path, cv2.IMREAD_COLOR)
    h, w, _ = image.shape

    # Keep-ratio rescale. Note: the target scale is hardcoded to (400, 400)
    # and does not follow input_config['input_shape'].
    img_scale = (400, 400)
    max_long_edge = max(img_scale)
    max_short_edge = min(img_scale)
    scale_factor = min(max_long_edge / max(h, w), max_short_edge / min(h, w))
    scale_w = int(w * float(scale_factor) + 0.5)
    scale_h = int(h * float(scale_factor) + 0.5)
    image = cv2.resize(image, (scale_w, scale_h))
    image = np.asarray(image).astype(np.float32)

    # Standard mmdet ImageNet mean/std (RGB order).
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
    image = imnormalize(image, mean, std)

    # Pad so both sides are divisible by 32.
    divisor = 32
    pad_h = int(np.ceil(image.shape[0] / divisor)) * divisor
    pad_w = int(np.ceil(image.shape[1] / divisor)) * divisor
    image = impad(image, shape=(pad_h, pad_w), pad_val=0)

    # HWC -> CHW, add a batch dimension.
    image = np.asarray(image).astype(np.float32)
    image = np.transpose(image, [2, 0, 1])
    one_img = torch.from_numpy(image).unsqueeze(0).float().requires_grad_(
        True)

    (_, C, H, W) = input_config['input_shape']
    one_meta = {
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': np.ones(4, dtype=np.float32),
        'flip': False,
        'show_img': image,  # note: this is the transposed CHW array
        'flip_direction': None
    }
    return one_img, one_meta
  103. print(f"onnxruntime device: {ort.get_device()}") # output: GPU
  104. print(f'ort avail providers: {ort.get_available_providers()}') # output: ['CUDAExecutionProvider', 'CPUExecutionProvider']
  105. output_file = "checkpoints/AD_dsxw_atts_20220318.onnx"
  106. classes_name = ['yiwei','loujian','celi','libei','fantie','lianxi','duojian','shunjian','shaoxi','jiahan','yiwu']
  107. onnx_model = onnx.load(output_file)
  108. onnx.checker.check_model(onnx_model)
  109. onnx_model = ONNXRuntimeDetector(output_file, classes_name, 0)
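
# Not in the original script: inspect the exported graph's I/O with a raw
# onnxruntime session, which helps when the detector wrapper misbehaves.
sess = ort.InferenceSession(output_file,
                            providers=ort.get_available_providers())
print('onnx inputs:', [(i.name, i.shape) for i in sess.get_inputs()])
print('onnx outputs:', [(o.name, o.shape) for o in sess.get_outputs()])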

input_shape = (1, 3, 400, 400)
input_img = "/home/shanwei-luo/userdata/datasets/dsxw_dataset_v4/dsxw_test/images/21000204.jpg"
# normalize_cfg is kept for reference; preprocess_example_input hardcodes
# the same mean/std values.
normalize_cfg = {
    'mean': (123.675, 116.28, 103.53),
    'std': (58.395, 57.12, 57.375)
}
input_config = {
    'input_shape': input_shape,
    'input_path': input_img,
    'normalize_cfg': normalize_cfg
}

# Override the shape used for the meta info, then prepare the input once again.
input_config['input_shape'] = (1, 3, 416, 416)
one_img, one_meta = preprocess_example_input(input_config)
print(one_img)

# Run the ONNX model on the GPU.
img_list, img_meta_list = [one_img], [[one_meta]]
img_list = [_.cuda().contiguous() for _ in img_list]
onnx_results = onnx_model(
    img_list, img_metas=img_meta_list, return_loss=False)[0]
print(onnx_results)

# Run the original PyTorch checkpoint on the same image for comparison.
config_file_1 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_dsxw_test61/AD_dsxw_test61.py'
checkpoint_file_1 = '/home/shanwei-luo/userdata/mmdetection/work_dirs/AD_dsxw_test61/epoch_36.pth'
model_1 = init_detector(config_file_1, checkpoint_file_1, device='cuda:1')
results_1_tmp = inference_detector(model_1, [input_img])
print(results_1_tmp)

# Per-class detection counts from both backends.
for i in range(len(onnx_results)):
    print(onnx_results[i].shape)
for i in range(len(results_1_tmp[0])):
    print(len(results_1_tmp[0][i]))
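
# A minimal numeric comparison of the two backends (a sketch, not in the
# original script). Each backend yields one (num_dets, 5) array per class;
# counts can differ because the two pipelines preprocess differently, and
# the 0.05 tolerance is an arbitrary choice.
for cls_idx, (onnx_dets, torch_dets) in enumerate(zip(onnx_results,
                                                      results_1_tmp[0])):
    torch_dets = np.asarray(torch_dets)
    if onnx_dets.shape == torch_dets.shape:
        close = np.allclose(onnx_dets, torch_dets, atol=0.05)
        print(f'class {cls_idx} ({classes_name[cls_idx]}): allclose={close}')
    else:
        print(f'class {cls_idx} ({classes_name[cls_idx]}): det counts differ '
              f'({onnx_dets.shape[0]} vs {torch_dets.shape[0]})')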
