You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_torchserver.py 2.4 kB

2 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from argparse import ArgumentParser
  2. import numpy as np
  3. import requests
  4. from mmdet.apis import inference_detector, init_detector, show_result_pyplot
  5. from mmdet.core import bbox2result
  6. def parse_args():
  7. parser = ArgumentParser()
  8. parser.add_argument('img', help='Image file')
  9. parser.add_argument('config', help='Config file')
  10. parser.add_argument('checkpoint', help='Checkpoint file')
  11. parser.add_argument('model_name', help='The model name in the server')
  12. parser.add_argument(
  13. '--inference-addr',
  14. default='127.0.0.1:8080',
  15. help='Address and port of the inference server')
  16. parser.add_argument(
  17. '--device', default='cuda:0', help='Device used for inference')
  18. parser.add_argument(
  19. '--score-thr', type=float, default=0.5, help='bbox score threshold')
  20. args = parser.parse_args()
  21. return args
  22. def parse_result(input, model_class):
  23. bbox = []
  24. label = []
  25. score = []
  26. for anchor in input:
  27. bbox.append(anchor['bbox'])
  28. label.append(model_class.index(anchor['class_name']))
  29. score.append([anchor['score']])
  30. bboxes = np.append(bbox, score, axis=1)
  31. labels = np.array(label)
  32. result = bbox2result(bboxes, labels, len(model_class))
  33. return result
  34. def main(args):
  35. # build the model from a config file and a checkpoint file
  36. model = init_detector(args.config, args.checkpoint, device=args.device)
  37. # test a single image
  38. model_result = inference_detector(model, args.img)
  39. for i, anchor_set in enumerate(model_result):
  40. anchor_set = anchor_set[anchor_set[:, 4] >= 0.5]
  41. model_result[i] = anchor_set
  42. # show the results
  43. show_result_pyplot(
  44. model,
  45. args.img,
  46. model_result,
  47. score_thr=args.score_thr,
  48. title='pytorch_result')
  49. url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
  50. with open(args.img, 'rb') as image:
  51. response = requests.post(url, image)
  52. server_result = parse_result(response.json(), model.CLASSES)
  53. show_result_pyplot(
  54. model,
  55. args.img,
  56. server_result,
  57. score_thr=args.score_thr,
  58. title='server_result')
  59. for i in range(len(model.CLASSES)):
  60. assert np.allclose(model_result[i], server_result[i])
  61. if __name__ == '__main__':
  62. args = parse_args()
  63. main(args)

No Description

Contributors (3)