@@ -14,8 +14,8 @@ from torch.utils.data import DataLoader, Dataset
 from torch.utils.data.dataloader import default_collate
 from torch.utils.data.distributed import DistributedSampler

-from modelscope.hub.utils.utils import create_library_statistics
 from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.hub.utils.utils import create_library_statistics
 from modelscope.metainfo import Trainers
 from modelscope.metrics import build_metric, task_default_metrics
 from modelscope.models.base import Model, TorchModel
@@ -438,7 +438,7 @@ class EpochBasedTrainer(BaseTrainer):
     def train(self, checkpoint_path=None, *args, **kwargs):
         self._mode = ModeKeys.TRAIN
         if hasattr(self.model, 'name'):
-            create_library_statistics("train", self.model.name, None)
+            create_library_statistics('train', self.model.name, None)

         if self.train_dataset is None:
             self.train_dataloader = self.get_train_dataloader()
@@ -460,7 +460,7 @@ class EpochBasedTrainer(BaseTrainer):

     def evaluate(self, checkpoint_path=None):
         if hasattr(self.model, 'name'):
-            create_library_statistics("evaluate", self.model.name, None)
+            create_library_statistics('evaluate', self.model.name, None)
         if checkpoint_path is not None and os.path.isfile(checkpoint_path):
             from modelscope.trainers.hooks import CheckpointHook
             CheckpointHook.load_checkpoint(checkpoint_path, self)