
fix style issues

Branch: master · mulin.lyh · 3 years ago · commit 9ae5b67204

3 changed files with 7 additions and 6 deletions:

  1. modelscope/hub/utils/utils.py   (+2 / -1)
  2. modelscope/pipelines/base.py    (+2 / -2)
  3. modelscope/trainers/trainer.py  (+3 / -3)
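All three files receive the same two mechanical fixes: import statements are regrouped into the conventional standard-library / third-party / first-party blocks, and double-quoted string literals become single-quoted. Below is a minimal sketch of the grouping rule, using names from the first diff; the rule itself is the PEP 8 / isort convention, which is assumed (not confirmed by the commit) to be what the project's linter enforces.

    # Standard-library imports form the first block.
    import hashlib
    import os
    from datetime import datetime
    from typing import Optional

    # Third-party packages get their own block, separated by a blank line.
    import requests

    # First-party (project) imports come last.
    from modelscope.hub.api import ModelScopeConfig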

modelscope/hub/utils/utils.py (+2 / -1)

@@ -2,10 +2,11 @@
 import hashlib
 import os
-import requests
 from datetime import datetime
 from typing import Optional
 
+import requests
+
 from modelscope.hub.api import ModelScopeConfig
 from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
                                       DEFAULT_MODELSCOPE_GROUP,
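This hunk is exactly what an import sorter produces: requests is third-party, so it moves out of the standard-library block into a section of its own. Assuming the project runs isort (the commit does not name its tooling), the same fix can be reproduced through isort's Python API:

    # Minimal sketch, assuming isort 5.x is installed.
    import isort

    messy = ('import hashlib\n'
             'import os\n'
             'import requests\n'
             'from datetime import datetime\n'
             'from typing import Optional\n')

    # isort.code() returns the source with imports regrouped; requests should
    # come back in a separate third-party block after the stdlib imports.
    print(isort.code(messy))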


modelscope/pipelines/base.py (+2 / -2)

@@ -10,6 +10,7 @@ from typing import Any, Dict, Generator, List, Mapping, Union
 
 import numpy as np
 
+from modelscope.hub.utils.utils import create_library_statistics
 from modelscope.models.base import Model
 from modelscope.msdatasets import MsDataset
 from modelscope.outputs import TASK_OUTPUTS
@@ -23,7 +24,6 @@ from modelscope.utils.hub import read_config, snapshot_download
 from modelscope.utils.import_utils import is_tf_available, is_torch_available
 from modelscope.utils.logger import get_logger
 from modelscope.utils.torch_utils import _find_free_port, _is_free_port
-from modelscope.hub.utils.utils import create_library_statistics
 from .util import is_model, is_official_hub_path
 
 if is_torch_available():
@@ -154,7 +154,7 @@ class Pipeline(ABC):
         # modelscope library developer will handle this function
         for single_model in self.models:
             if hasattr(single_model, 'name'):
-                create_library_statistics("pipeline", single_model.name, None)
+                create_library_statistics('pipeline', single_model.name, None)
         # place model to cpu or gpu
         if (self.model or (self.has_multiple_models and self.models[0])):
             if not self._model_prepare:
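The hunk above, like the two in trainer.py below, only normalizes quoting. Single- and double-quoted literals are identical at runtime, so the change is purely a consistency fix, as a one-line check shows:

    # The two literals compare equal; only the style convention differs.
    assert "pipeline" == 'pipeline'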


modelscope/trainers/trainer.py (+3 / -3)

@@ -14,8 +14,8 @@ from torch.utils.data import DataLoader, Dataset
 from torch.utils.data.dataloader import default_collate
 from torch.utils.data.distributed import DistributedSampler
 
-from modelscope.hub.utils.utils import create_library_statistics
 from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.hub.utils.utils import create_library_statistics
 from modelscope.metainfo import Trainers
 from modelscope.metrics import build_metric, task_default_metrics
 from modelscope.models.base import Model, TorchModel
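Within the first-party block, the ordering is alphabetical as well, which is why modelscope.hub.snapshot_download now precedes modelscope.hub.utils.utils; the set of imports is unchanged, only their order.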
@@ -438,7 +438,7 @@ class EpochBasedTrainer(BaseTrainer):
     def train(self, checkpoint_path=None, *args, **kwargs):
         self._mode = ModeKeys.TRAIN
         if hasattr(self.model, 'name'):
-            create_library_statistics("train", self.model.name, None)
+            create_library_statistics('train', self.model.name, None)
 
         if self.train_dataset is None:
             self.train_dataloader = self.get_train_dataloader()
@@ -460,7 +460,7 @@ class EpochBasedTrainer(BaseTrainer):
 
     def evaluate(self, checkpoint_path=None):
         if hasattr(self.model, 'name'):
-            create_library_statistics("evaluate", self.model.name, None)
+            create_library_statistics('evaluate', self.model.name, None)
         if checkpoint_path is not None and os.path.isfile(checkpoint_path):
            from modelscope.trainers.hooks import CheckpointHook
            CheckpointHook.load_checkpoint(checkpoint_path, self)
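Every call site wraps create_library_statistics in a hasattr guard, so statistics are reported only for models that expose a name attribute. A self-contained sketch of that pattern, where report_usage is a hypothetical stand-in for the real helper in modelscope/hub/utils/utils.py (which presumably reports usage over HTTP, given that it imports requests):

    # Hypothetical stand-in for create_library_statistics; prints instead of
    # reporting to the hub.
    def report_usage(op, name):
        print(f'{op}: {name}')

    class NamedModel:
        name = 'damo/cv_resnet50'  # hypothetical model id, for illustration

    class AnonymousModel:
        pass  # no `name` attribute

    for model in (NamedModel(), AnonymousModel()):
        if hasattr(model, 'name'):  # guard: models without a name are skipped
            report_usage('train', model.name)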

