|
|
|
@@ -1,3 +1,4 @@ |
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates. |
|
|
|
import argparse |
|
|
|
import logging as logger |
|
|
|
import os |
|
|
|
@@ -48,6 +49,7 @@ class RealtimeDetector(TorchModel): |
|
|
|
self.nmsthre = self.exp.nmsthre |
|
|
|
self.test_size = self.exp.test_size |
|
|
|
self.preproc = ValTransform(legacy=False) |
|
|
|
self.label_mapping = self.config['labels'] |
|
|
|
|
|
|
|
def inference(self, img): |
|
|
|
with torch.no_grad(): |
|
|
|
@@ -81,5 +83,8 @@ class RealtimeDetector(TorchModel): |
|
|
|
bboxes = outputs[0][:, 0:4].cpu().numpy() / self.ratio |
|
|
|
scores = outputs[0][:, 5].cpu().numpy() |
|
|
|
labels = outputs[0][:, 6].cpu().int().numpy() |
|
|
|
pred_label_names = [] |
|
|
|
for lab in labels: |
|
|
|
pred_label_names.append(self.label_mapping[lab]) |
|
|
|
|
|
|
|
return bboxes, scores, labels |
|
|
|
return bboxes, scores, pred_label_names |