|
|
|
@@ -39,6 +39,7 @@ from modelscope.utils.logger import get_logger |
|
|
|
from modelscope.utils.registry import build_from_cfg |
|
|
|
from modelscope.utils.torch_utils import (get_dist_info, get_local_rank, |
|
|
|
init_dist, set_random_seed) |
|
|
|
from modelscope.hub.api import HubApi |
|
|
|
from .base import BaseTrainer |
|
|
|
from .builder import TRAINERS |
|
|
|
from .default_config import merge_cfg |
|
|
|
@@ -436,6 +437,9 @@ class EpochBasedTrainer(BaseTrainer): |
|
|
|
|
|
|
|
def train(self, checkpoint_path=None, *args, **kwargs): |
|
|
|
self._mode = ModeKeys.TRAIN |
|
|
|
_api = HubApi() |
|
|
|
model_name = self.cfg.task |
|
|
|
_api.create_library_statistics("train", model_name, None) |
|
|
|
|
|
|
|
if self.train_dataset is None: |
|
|
|
self.train_dataloader = self.get_train_dataloader() |
|
|
|
@@ -456,6 +460,9 @@ class EpochBasedTrainer(BaseTrainer): |
|
|
|
self.train_loop(self.train_dataloader) |
|
|
|
|
|
|
|
def evaluate(self, checkpoint_path=None): |
|
|
|
_api = HubApi() |
|
|
|
model_name = self.cfg.task |
|
|
|
_api.create_library_statistics("evaluate", model_name, None) |
|
|
|
if checkpoint_path is not None and os.path.isfile(checkpoint_path): |
|
|
|
from modelscope.trainers.hooks import CheckpointHook |
|
|
|
CheckpointHook.load_checkpoint(checkpoint_path, self) |
|
|
|
|