|
|
|
@@ -93,7 +93,7 @@ class TextDrivenSeg(TorchModel): |
|
|
|
""" |
|
|
|
with torch.no_grad(): |
|
|
|
if self.device_id == -1: |
|
|
|
output = self.model(image) |
|
|
|
output = self.model(image, [text]) |
|
|
|
else: |
|
|
|
device = torch.device('cuda', self.device_id) |
|
|
|
output = self.model(image.to(device), [text]) |
|
|
|
|