
mmdet_handler.py

# Copyright (c) OpenMMLab. All rights reserved.
import base64
import os

import mmcv
import torch
from ts.torch_handler.base_handler import BaseHandler

from mmdet.apis import inference_detector, init_detector


class MMdetHandler(BaseHandler):
    threshold = 0.5

    def initialize(self, context):
        properties = context.system_properties
        self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(self.map_location + ':' +
                                   str(properties.get('gpu_id')) if torch.cuda.
                                   is_available() else self.map_location)
        self.manifest = context.manifest

        model_dir = properties.get('model_dir')
        serialized_file = self.manifest['model']['serializedFile']
        checkpoint = os.path.join(model_dir, serialized_file)
        self.config_file = os.path.join(model_dir, 'config.py')

        self.model = init_detector(self.config_file, checkpoint, self.device)
        self.initialized = True

    def preprocess(self, data):
        images = []

        for row in data:
            image = row.get('data') or row.get('body')
            if isinstance(image, str):
                image = base64.b64decode(image)
            image = mmcv.imfrombytes(image)
            images.append(image)

        return images

    def inference(self, data, *args, **kwargs):
        results = inference_detector(self.model, data)
        return results

    def postprocess(self, data):
        # Format output following the example ObjectDetectionHandler format
        output = []
        for image_index, image_result in enumerate(data):
            output.append([])
            if isinstance(image_result, tuple):
                bbox_result, segm_result = image_result
                if isinstance(segm_result, tuple):
                    segm_result = segm_result[0]  # ms rcnn
            else:
                bbox_result, segm_result = image_result, None
            for class_index, class_result in enumerate(bbox_result):
                class_name = self.model.CLASSES[class_index]
                for bbox in class_result:
                    bbox_coords = bbox[:-1].tolist()
                    score = float(bbox[-1])
                    if score >= self.threshold:
                        output[image_index].append({
                            'class_name': class_name,
                            'bbox': bbox_coords,
                            'score': score
                        })
        return output
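
For context, the sketch below shows how a handler like this is typically exercised from the client side once the model archive is being served by TorchServe. It is a minimal example, not part of the handler itself: the model name 'mmdet', the server address, and the image path are assumptions for illustration and should be adjusted to the actual deployment.

# Minimal client sketch, assuming TorchServe is already running this handler
# under the (hypothetical) model name 'mmdet' on the default inference port 8080.
import requests

# Send the raw image bytes; preprocess() decodes them with mmcv.imfrombytes.
with open('demo.jpg', 'rb') as f:  # any local test image (path is an assumption)
    resp = requests.post('http://127.0.0.1:8080/predictions/mmdet', data=f)

# postprocess() produces one list of detections per input image, e.g.
# [{'class_name': 'person', 'bbox': [x1, y1, x2, y2], 'score': 0.98}, ...]
print(resp.json())

Only detections whose score meets the handler's class-level threshold (0.5 by default) appear in the response.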
