|
|
|
@@ -90,7 +90,7 @@ class TinynasClassificationPipeline(Pipeline): |
|
|
|
output_prob = torch.nn.functional.softmax(inputs['outputs'], dim=-1) |
|
|
|
score = torch.max(output_prob) |
|
|
|
output_dict = { |
|
|
|
OutputKeys.SCORES: score.item(), |
|
|
|
OutputKeys.LABELS: label_dict[inputs['outputs'].argmax().item()] |
|
|
|
OutputKeys.SCORES: [score.item()], |
|
|
|
OutputKeys.LABELS: [label_dict[inputs['outputs'].argmax().item()]] |
|
|
|
} |
|
|
|
return output_dict |