|
|
@@ -27,9 +27,9 @@ class ClassificationModel(TorchModel): |
|
|
|
|
|
|
|
|
self.load_pretrained_checkpoint() |
|
|
self.load_pretrained_checkpoint() |
|
|
|
|
|
|
|
|
def forward(self, Inputs): |
|
|
|
|
|
|
|
|
def forward(self, inputs): |
|
|
|
|
|
|
|
|
return self.cls_model(**Inputs) |
|
|
|
|
|
|
|
|
return self.cls_model(**inputs) |
|
|
|
|
|
|
|
|
def load_pretrained_checkpoint(self): |
|
|
def load_pretrained_checkpoint(self): |
|
|
import mmcv |
|
|
import mmcv |
|
|
|