baiguan.yt · yingda.chen · 3 years ago
commit d8c98c51f8
3 changed files with 73 additions and 32 deletions:

  1. modelscope/pipelines/cv/image_colorization_pipeline.py (+6, -8)
  2. modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py (+38, -19)
  3. modelscope/pipelines/cv/image_super_resolution_pipeline.py (+29, -5)

modelscope/pipelines/cv/image_colorization_pipeline.py (+6, -8)

@@ -12,7 +12,7 @@ from modelscope.models.cv.image_colorization import (DynamicUnetDeep,
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import LoadImage, load_image
+from modelscope.preprocessors import LoadImage
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger

@@ -63,11 +63,11 @@ class ImageColorizationPipeline(Pipeline):
                 last_cross=True,
                 bottle=False,
                 nf_factor=2,
-            )
+            ).to(self.device)
         else:
             body = models.resnet34(pretrained=True)
             body = torch.nn.Sequential(*list(body.children())[:cut])
-            model = DynamicUnetDeep(
+            self.model = DynamicUnetDeep(
                 body,
                 n_classes=3,
                 blur=True,
@@ -78,7 +78,7 @@ class ImageColorizationPipeline(Pipeline):
                 last_cross=True,
                 bottle=False,
                 nf_factor=1.5,
-            )
+            ).to(self.device)
 
         model_path = f'{model}/{ModelFile.TORCH_MODEL_FILE}'
         self.model.load_state_dict(
@@ -91,10 +91,8 @@ class ImageColorizationPipeline(Pipeline):
         img = LoadImage.convert_to_img(input).convert('LA').convert('RGB')
 
         self.wide, self.height = img.size
-        if self.wide * self.height > self.size * self.size:
-            self.orig_img = img.copy()
-            img = img.resize((self.size, self.size),
-                             resample=PIL.Image.BILINEAR)
+        self.orig_img = img.copy()
+        img = img.resize((self.size, self.size), resample=PIL.Image.BILINEAR)
 
         img = self.norm(img).unsqueeze(0).to(self.device)
         result = {'img': img}
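Net effect on image_colorization_pipeline.py: the network is built directly on self.device, and every input is now unconditionally resized to a fixed size while a full-resolution copy is kept in self.orig_img (previously only images larger than size × size were downscaled). A minimal sketch of the resulting preprocessing flow, with `size`, `norm`, and `device` as assumed stand-ins for the pipeline's attributes:

```python
import PIL.Image
import torch
import torchvision.transforms as transforms

# Illustrative stand-ins for the pipeline's attributes (names assumed).
size = 512
norm = transforms.ToTensor()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def colorization_preprocess(img: PIL.Image.Image):
    """Mirror the new preprocess flow: always keep a full-resolution copy
    and feed the network a fixed-size bilinear resize."""
    img = img.convert('LA').convert('RGB')  # grayscale, back to 3 channels
    orig_img = img.copy()                   # kept to restore resolution later
    img = img.resize((size, size), resample=PIL.Image.BILINEAR)
    return norm(img).unsqueeze(0).to(device), orig_img
```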


modelscope/pipelines/cv/image_portrait_enhancement_pipeline.py (+38, -19)

@@ -5,6 +5,7 @@ import cv2
 import numpy as np
 import PIL
 import torch
+import torch.nn.functional as F
 from scipy.ndimage import gaussian_filter
 from scipy.spatial.distance import pdist, squareform

@@ -118,7 +119,7 @@ class ImagePortraitEnhancementPipeline(Pipeline):
         img_t = torch.from_numpy(img).to(self.device) / 255.
         if is_norm:
             img_t = (img_t - 0.5) / 0.5
-        img_t = img_t.permute(2, 0, 1).unsqueeze(0).flip(1)  # BGR->RGB
+        img_t = img_t.permute(2, 0, 1).unsqueeze(0)
         return img_t
 
     def tensor2img(self, img_t, pmax=255.0, is_denorm=True, imtype=np.uint8):
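img2tensor no longer flips the channel axis, so the tensor keeps whatever channel order the caller supplies; BGR/RGB handling is consolidated in the new sr_process below. A standalone sketch of the updated conversion (hypothetical free function; the real method uses self.device):

```python
import numpy as np
import torch


def img2tensor(img: np.ndarray, is_norm: bool = True,
               device: str = 'cpu') -> torch.Tensor:
    """HWC uint8 image -> 1xCxHxW float tensor in [0, 1] (or [-1, 1]).
    Channel order is now passed through unchanged."""
    img_t = torch.from_numpy(img).to(device) / 255.
    if is_norm:
        img_t = (img_t - 0.5) / 0.5
    return img_t.permute(2, 0, 1).unsqueeze(0)
```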
@@ -129,19 +130,46 @@ class ImagePortraitEnhancementPipeline(Pipeline):
 
         return img_np.astype(imtype)
 
+    def sr_process(self, img):
+        img = img.astype(np.float32) / 255.
+        img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
+        img = img.unsqueeze(0).to(self.device)
+
+        if self.scale == 2:
+            mod_scale = 2
+        elif self.scale == 1:
+            mod_scale = 4
+        else:
+            mod_scale = None
+        if mod_scale is not None:
+            h_pad, w_pad = 0, 0
+            _, _, h, w = img.size()
+            if (h % mod_scale != 0):
+                h_pad = (mod_scale - h % mod_scale)
+            if (w % mod_scale != 0):
+                w_pad = (mod_scale - w % mod_scale)
+            img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect')
+
+        self.sr_model.eval()
+        with torch.no_grad():
+            output = self.sr_model(img)
+        del img
+        # remove extra pad
+        if mod_scale is not None:
+            _, _, h, w = output.size()
+            output = output[:, :, 0:h - h_pad, 0:w - w_pad]
+        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
+        output = (output * 255.0).round().astype(np.uint8)
+
+        return output
+
     def preprocess(self, input: Input) -> Dict[str, Any]:
         img = LoadImage.convert_to_ndarray(input)
 
-        img_sr = None
+        img_sr = img
         if self.use_sr:
-            self.sr_model.eval()
-            with torch.no_grad():
-                img_t = self.img2tensor(img, is_norm=False)
-                img_out = self.sr_model(img_t)
-
-                img_sr = img_out.squeeze(0).permute(1, 2, 0).flip(2).cpu().clamp_(
-                    0, 1).numpy()
-                img_sr = (img_sr * 255.0).round().astype(np.uint8)
+            img_sr = self.sr_process(img)
 
         img = cv2.resize(img, img_sr.shape[:2][::-1])
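sr_process reflect-pads the input so its height and width are divisible by mod_scale before inference, then crops the padding back out of the result. A self-contained sketch of that pad-infer-crop round trip, assuming a model that upscales by `scale` (so the crop amounts grow with the output):

```python
import torch
import torch.nn.functional as F


def pad_infer_crop(img: torch.Tensor, model: torch.nn.Module,
                   mod_scale: int, scale: int) -> torch.Tensor:
    """Reflect-pad a 1xCxHxW tensor so H and W divide mod_scale, run the
    model, then crop the (now upscaled) padding off the output."""
    _, _, h, w = img.size()
    h_pad = (mod_scale - h % mod_scale) % mod_scale
    w_pad = (mod_scale - w % mod_scale) % mod_scale
    img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect')  # pad right/bottom
    with torch.no_grad():
        out = model(img)
    # the padding grows with the output, so the crop is scaled accordingly
    return out[:, :, :out.size(2) - h_pad * scale,
               :out.size(3) - w_pad * scale]
```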

@@ -160,7 +188,6 @@ class ImagePortraitEnhancementPipeline(Pipeline):
         for i, (faceb, facial5points) in enumerate(zip(facebs, landms)):
             if faceb[4] < self.threshold:
                 continue
-            # fh, fw = (faceb[3] - faceb[1]), (faceb[2] - faceb[0])
 
             facial5points = np.reshape(facial5points, (2, 5))

@@ -184,14 +211,6 @@ class ImagePortraitEnhancementPipeline(Pipeline):
             if dist > self.id_thres:
                 continue
 
-            # blending parameter
-            fq = max(1., (fq_o - self.fqa_thres))
-            fq = (1 - 2 * dist) * (1.0 / (1 + math.exp(-(2 * fq - 1))))
-
-            # blend face
-            ef = cv2.addWeighted(ef, fq * self.alpha, of, 1 - fq * self.alpha,
-                                 0.0)
-
             tmp_mask = self.mask
             tmp_mask = cv2.resize(tmp_mask, ef.shape[:2])
             tmp_mask = cv2.warpAffine(
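For reference, the deleted block blended the enhanced face `ef` into the original crop `of` with a quality-driven weight. Its logic in isolation (a sketch; the `alpha` and `fqa_thres` defaults are illustrative stand-ins for self.alpha and self.fqa_thres):

```python
import math

import cv2


def blend_enhanced_face(ef, of, fq_o, dist, alpha=1.0, fqa_thres=0.0):
    """Reproduce the removed step: blend the enhanced face `ef` with the
    original crop `of`, weighted by face quality and identity distance."""
    fq = max(1., fq_o - fqa_thres)                               # quality margin
    fq = (1 - 2 * dist) * (1.0 / (1 + math.exp(-(2 * fq - 1))))  # sigmoid squash
    return cv2.addWeighted(ef, fq * alpha, of, 1 - fq * alpha, 0.0)
```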


modelscope/pipelines/cv/image_super_resolution_pipeline.py (+29, -5)

@@ -4,6 +4,7 @@ import cv2
 import numpy as np
 import PIL
 import torch
+import torch.nn.functional as F
 
 from modelscope.metainfo import Pipelines
 from modelscope.models.cv.super_resolution import RRDBNet
@@ -32,6 +33,7 @@ class ImageSuperResolutionPipeline(Pipeline):
             self.device = torch.device('cuda')
         else:
             self.device = torch.device('cpu')
+
         self.num_feat = 64
         self.num_block = 23
         self.scale = 4
@@ -58,13 +60,35 @@ class ImageSuperResolutionPipeline(Pipeline):
 
     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
         self.sr_model.eval()
-        with torch.no_grad():
-            out = self.sr_model(input['img'])
 
-            out = out.squeeze(0).permute(1, 2, 0).flip(2)
-            out_img = np.clip(out.float().cpu().numpy(), 0, 1) * 255
+        img = input['img']
+        if self.scale == 2:
+            mod_scale = 2
+        elif self.scale == 1:
+            mod_scale = 4
+        else:
+            mod_scale = None
+        if mod_scale is not None:
+            h_pad, w_pad = 0, 0
+            _, _, h, w = img.size()
+            if (h % mod_scale != 0):
+                h_pad = (mod_scale - h % mod_scale)
+            if (w % mod_scale != 0):
+                w_pad = (mod_scale - w % mod_scale)
+            img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect')
+
+        with torch.no_grad():
+            output = self.sr_model(img)
+        del img
+        # remove extra pad
+        if mod_scale is not None:
+            _, _, h, w = output.size()
+            output = output[:, :, 0:h - h_pad, 0:w - w_pad]
+        output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+        output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
+        output = (output * 255.0).round().astype(np.uint8)
 
-        return {OutputKeys.OUTPUT_IMG: out_img.astype(np.uint8)}
+        return {OutputKeys.OUTPUT_IMG: output}
 
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         return inputs
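Both rewritten pipelines now end with the same tensor-to-image conversion. Spelled out step by step (a sketch, assuming the model emits an RGB tensor in [0, 1] and the pipeline returns a BGR uint8 image for OpenCV):

```python
import numpy as np
import torch


def tensor_to_bgr_uint8(output: torch.Tensor) -> np.ndarray:
    """1xCxHxW float tensor in [0, 1] -> HxWxC uint8 BGR image."""
    out = output.squeeze().float().cpu().clamp_(0, 1).numpy()  # CxHxW
    out = out[[2, 1, 0], :, :]            # RGB -> BGR channel reorder
    out = np.transpose(out, (1, 2, 0))    # CxHxW -> HxWxC
    return (out * 255.0).round().astype(np.uint8)
```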
