|
|
|
@@ -437,8 +437,8 @@ class EpochBasedTrainer(BaseTrainer): |
|
|
|
|
|
|
|
def train(self, checkpoint_path=None, *args, **kwargs): |
|
|
|
self._mode = ModeKeys.TRAIN |
|
|
|
model_name = self.cfg.model.type |
|
|
|
create_library_statistics("train", model_name, None) |
|
|
|
if hasattr(self.model, 'name'): |
|
|
|
create_library_statistics("train", self.model.name, None) |
|
|
|
|
|
|
|
if self.train_dataset is None: |
|
|
|
self.train_dataloader = self.get_train_dataloader() |
|
|
|
@@ -459,8 +459,8 @@ class EpochBasedTrainer(BaseTrainer): |
|
|
|
self.train_loop(self.train_dataloader) |
|
|
|
|
|
|
|
def evaluate(self, checkpoint_path=None): |
|
|
|
model_name = self.cfg.model.type |
|
|
|
create_library_statistics("evaluate", model_name, None) |
|
|
|
if hasattr(self.model, 'name'): |
|
|
|
create_library_statistics("evaluate", self.model.name, None) |
|
|
|
if checkpoint_path is not None and os.path.isfile(checkpoint_path): |
|
|
|
from modelscope.trainers.hooks import CheckpointHook |
|
|
|
CheckpointHook.load_checkpoint(checkpoint_path, self) |
|
|
|
|