|
|
|
@@ -14,52 +14,14 @@ from modelscope.hub.snapshot_download import snapshot_download |
|
|
|
from modelscope.metainfo import Trainers |
|
|
|
from modelscope.models.cv.image_portrait_enhancement import \ |
|
|
|
ImagePortraitEnhancement |
|
|
|
from modelscope.msdatasets import MsDataset |
|
|
|
from modelscope.msdatasets.task_datasets.image_portrait_enhancement import \ |
|
|
|
ImagePortraitEnhancementDataset |
|
|
|
from modelscope.trainers import build_trainer |
|
|
|
from modelscope.utils.constant import ModelFile |
|
|
|
from modelscope.utils.constant import DownloadMode, ModelFile |
|
|
|
from modelscope.utils.test_utils import test_level |
|
|
|
|
|
|
|
|
|
|
|
class PairedImageDataset(data.Dataset): |
|
|
|
|
|
|
|
def __init__(self, root, size=512): |
|
|
|
super(PairedImageDataset, self).__init__() |
|
|
|
self.size = size |
|
|
|
gt_dir = osp.join(root, 'gt') |
|
|
|
lq_dir = osp.join(root, 'lq') |
|
|
|
self.gt_filelist = os.listdir(gt_dir) |
|
|
|
self.gt_filelist = sorted(self.gt_filelist, key=lambda x: int(x[:-4])) |
|
|
|
self.gt_filelist = [osp.join(gt_dir, f) for f in self.gt_filelist] |
|
|
|
self.lq_filelist = os.listdir(lq_dir) |
|
|
|
self.lq_filelist = sorted(self.lq_filelist, key=lambda x: int(x[:-4])) |
|
|
|
self.lq_filelist = [osp.join(lq_dir, f) for f in self.lq_filelist] |
|
|
|
|
|
|
|
def _img_to_tensor(self, img): |
|
|
|
img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1).type( |
|
|
|
torch.float32) / 255. |
|
|
|
return (img - 0.5) / 0.5 |
|
|
|
|
|
|
|
def __getitem__(self, index): |
|
|
|
lq = cv2.imread(self.lq_filelist[index]) |
|
|
|
gt = cv2.imread(self.gt_filelist[index]) |
|
|
|
lq = cv2.resize( |
|
|
|
lq, (self.size, self.size), interpolation=cv2.INTER_CUBIC) |
|
|
|
gt = cv2.resize( |
|
|
|
gt, (self.size, self.size), interpolation=cv2.INTER_CUBIC) |
|
|
|
|
|
|
|
return \ |
|
|
|
{'src': self._img_to_tensor(lq), 'target': self._img_to_tensor(gt)} |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
return len(self.gt_filelist) |
|
|
|
|
|
|
|
def to_torch_dataset(self, |
|
|
|
columns: Union[str, List[str]] = None, |
|
|
|
preprocessors: Union[Callable, List[Callable]] = None, |
|
|
|
**format_kwargs): |
|
|
|
# self.preprocessor = preprocessors |
|
|
|
return self |
|
|
|
|
|
|
|
|
|
|
|
class TestImagePortraitEnhancementTrainer(unittest.TestCase): |
|
|
|
|
|
|
|
def setUp(self): |
|
|
|
@@ -70,8 +32,23 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase): |
|
|
|
|
|
|
|
self.model_id = 'damo/cv_gpen_image-portrait-enhancement' |
|
|
|
|
|
|
|
self.dataset = PairedImageDataset( |
|
|
|
'./data/test/images/face_enhancement/') |
|
|
|
dataset_train = MsDataset.load( |
|
|
|
'image-portrait-enhancement-dataset', |
|
|
|
namespace='modelscope', |
|
|
|
subset_name='default', |
|
|
|
split='test', |
|
|
|
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds |
|
|
|
dataset_val = MsDataset.load( |
|
|
|
'image-portrait-enhancement-dataset', |
|
|
|
namespace='modelscope', |
|
|
|
subset_name='default', |
|
|
|
split='test', |
|
|
|
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds |
|
|
|
|
|
|
|
self.dataset_train = ImagePortraitEnhancementDataset( |
|
|
|
dataset_train, is_train=True) |
|
|
|
self.dataset_val = ImagePortraitEnhancementDataset( |
|
|
|
dataset_val, is_train=False) |
|
|
|
|
|
|
|
def tearDown(self): |
|
|
|
shutil.rmtree(self.tmp_dir, ignore_errors=True) |
|
|
|
@@ -81,8 +58,8 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase): |
|
|
|
def test_trainer(self): |
|
|
|
kwargs = dict( |
|
|
|
model=self.model_id, |
|
|
|
train_dataset=self.dataset, |
|
|
|
eval_dataset=self.dataset, |
|
|
|
train_dataset=self.dataset_train, |
|
|
|
eval_dataset=self.dataset_val, |
|
|
|
device='gpu', |
|
|
|
work_dir=self.tmp_dir) |
|
|
|
|
|
|
|
@@ -101,8 +78,8 @@ class TestImagePortraitEnhancementTrainer(unittest.TestCase): |
|
|
|
kwargs = dict( |
|
|
|
cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), |
|
|
|
model=model, |
|
|
|
train_dataset=self.dataset, |
|
|
|
eval_dataset=self.dataset, |
|
|
|
train_dataset=self.dataset_train, |
|
|
|
eval_dataset=self.dataset_val, |
|
|
|
device='gpu', |
|
|
|
max_epochs=2, |
|
|
|
work_dir=self.tmp_dir) |
|
|
|
|