|
|
|
@@ -437,7 +437,7 @@ class EpochBasedTrainer(BaseTrainer): |
|
|
|
|
|
|
|
def train(self, checkpoint_path=None, *args, **kwargs): |
|
|
|
self._mode = ModeKeys.TRAIN |
|
|
|
model_name = self.cfg.task |
|
|
|
model_name = self.cfg.model.type |
|
|
|
create_library_statistics("train", model_name, None) |
|
|
|
|
|
|
|
if self.train_dataset is None: |
|
|
|
@@ -459,7 +459,7 @@ class EpochBasedTrainer(BaseTrainer): |
|
|
|
self.train_loop(self.train_dataloader) |
|
|
|
|
|
|
|
def evaluate(self, checkpoint_path=None): |
|
|
|
model_name = self.cfg.task |
|
|
|
model_name = self.cfg.model.type |
|
|
|
create_library_statistics("evaluate", model_name, None) |
|
|
|
if checkpoint_path is not None and os.path.isfile(checkpoint_path): |
|
|
|
from modelscope.trainers.hooks import CheckpointHook |
|
|
|
|