|
|
|
@@ -16,6 +16,7 @@ class TrainEstimator(BaseEstimator): |
|
|
|
""" |
|
|
|
def __init__(self, loss_f = "nll_loss", evaluation = [Acc()]): |
|
|
|
super().__init__(loss_f, evaluation) |
|
|
|
self.evaluation = evaluation |
|
|
|
self.estimator=OneShotEstimator(self.loss_f, self.evaluation) |
|
|
|
|
|
|
|
def infer(self, model: BaseSpace, dataset, mask="train"): |
|
|
|
@@ -34,6 +35,14 @@ class TrainEstimator(BaseEstimator): |
|
|
|
feval=self.evaluation, |
|
|
|
loss=self.loss_f, |
|
|
|
lr_scheduler_type=None) |
|
|
|
self.trainer.train(dataset) |
|
|
|
with torch.no_grad(): |
|
|
|
return self.estimator.infer(boxmodel.model, dataset, mask) |
|
|
|
try: |
|
|
|
self.trainer.train(dataset) |
|
|
|
with torch.no_grad(): |
|
|
|
return self.estimator.infer(boxmodel.model, dataset, mask) |
|
|
|
except RuntimeError as e: |
|
|
|
if "cuda" in str(e) or "CUDA" in str(e): |
|
|
|
INF = 100 |
|
|
|
fin = [-INF if eva.is_higher_better else INF for eva in self.evaluation] |
|
|
|
return fin, 0 |
|
|
|
else: |
|
|
|
raise e |