|
|
|
@@ -114,9 +114,8 @@ class GeneralRecognitionPipeline(Pipeline): |
|
|
|
label_mapping = f.readlines() |
|
|
|
score = torch.max(inputs['outputs']) |
|
|
|
inputs = { |
|
|
|
OutputKeys.SCORES: |
|
|
|
score.item(), |
|
|
|
OutputKeys.SCORES: [score.item()], |
|
|
|
OutputKeys.LABELS: |
|
|
|
label_mapping[inputs['outputs'].argmax()].split('\t')[1] |
|
|
|
[label_mapping[inputs['outputs'].argmax()].split('\t')[1]] |
|
|
|
} |
|
|
|
return inputs |