
[to #42322933] Add video-inpainting files

Code review for video editing
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10026166
master
tingwei.gtw authored, yingda.chen committed 3 years ago
parent commit 7b23c41748
13 changed files with 798 additions and 1 deletion
  1. data/test/videos/mask_dir/mask_00000_00320.png (+3, -0)
  2. data/test/videos/mask_dir/mask_00321_00633.png (+3, -0)
  3. data/test/videos/video_inpainting_test.mp4 (+3, -0)
  4. modelscope/metainfo.py (+2, -0)
  5. modelscope/models/cv/video_inpainting/__init__.py (+20, -0)
  6. modelscope/models/cv/video_inpainting/inpainting.py (+298, -0)
  7. modelscope/models/cv/video_inpainting/inpainting_model.py (+373, -0)
  8. modelscope/outputs.py (+5, -0)
  9. modelscope/pipelines/builder.py (+2, -0)
  10. modelscope/pipelines/cv/video_inpainting_pipeline.py (+47, -0)
  11. modelscope/utils/constant.py (+3, -0)
  12. tests/pipelines/test_person_image_cartoon.py (+0, -1)
  13. tests/pipelines/test_video_inpainting.py (+39, -0)

data/test/videos/mask_dir/mask_00000_00320.png (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b158f6029d9763d7f84042f7c5835f398c688fdbb6b3f4fe6431101d4118c66c
size 2766

data/test/videos/mask_dir/mask_00321_00633.png (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0dcf46b93077e2229ab69cd6ddb80e2689546c575ee538bb2033fee1124ef3e3
size 2761

data/test/videos/video_inpainting_test.mp4 (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c9870df5a86acaaec67063183dace795479cd0f05296f13058995f475149c56
size 2957783

modelscope/metainfo.py (+2, -0)

@@ -38,6 +38,7 @@ class Models(object):
    mogface = 'mogface'
    mtcnn = 'mtcnn'
    ulfd = 'ulfd'
    video_inpainting = 'video-inpainting'

    # EasyCV models
    yolox = 'YOLOX'
@@ -169,6 +170,7 @@ class Pipelines(object):
    text_driven_segmentation = 'text-driven-segmentation'
    movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation'
    shop_segmentation = 'shop-segmentation'
    video_inpainting = 'video-inpainting'

    # nlp tasks
    sentence_similarity = 'sentence-similarity'


modelscope/models/cv/video_inpainting/__init__.py (+20, -0)

@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .inpainting_model import VideoInpainting

else:
    _import_structure = {'inpainting_model': ['VideoInpainting']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

modelscope/models/cv/video_inpainting/inpainting.py (+298, -0)

@@ -0,0 +1,298 @@
""" VideoInpaintingProcess
Base modules are adapted from https://github.com/researchmm/STTN,
originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab,
"""

import os
import time

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

torch.backends.cudnn.enabled = False

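# Working resolution and temporal windowing used throughout this module:
#   w, h            - frames are resized to 192x96 before inpainting
#   ref_length      - stride for sampling long-range reference frames
#   neighbor_stride - radius of the local temporal attention window
#   MAX_frame       - frames processed per chunk, to bound memory use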
w, h = 192, 96
ref_length = 300
neighbor_stride = 20
default_fps = 24
MAX_frame = 300


def video_process(video_input_path):
    video_input = cv2.VideoCapture(video_input_path)
    success, frame = video_input.read()
    if success is False:
        decode_error = 'decode_error'
        w, h, fps = 0, 0, 0
    else:
        decode_error = None
        h, w = frame.shape[0:2]
        fps = video_input.get(cv2.CAP_PROP_FPS)
    video_input.release()

    return decode_error, fps, w, h


class Stack(object):

    def __init__(self, roll=False):
        self.roll = roll

    def __call__(self, img_group):
        mode = img_group[0].mode
        if mode == '1':
            img_group = [img.convert('L') for img in img_group]
            mode = 'L'
        if mode == 'L':
            return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
        elif mode == 'RGB':
            if self.roll:
                return np.stack([np.array(x)[:, :, ::-1] for x in img_group],
                                axis=2)
            else:
                return np.stack(img_group, axis=2)
        else:
            raise NotImplementedError(f'Image mode {mode}')


class ToTorchFormatTensor(object):
    """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range
    [0.0, 1.0]. """

    def __init__(self, div=True):
        self.div = div

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        img = img.float().div(255) if self.div else img.float()
        return img


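# Composed transform: a list of T PIL frames becomes a float tensor of shape
# (T, C, H, W) scaled to [0, 1]; callers add the batch dim with unsqueeze(0).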
_to_tensors = transforms.Compose([Stack(), ToTorchFormatTensor()])


def get_crop_mask_v1(mask):
    orig_h, orig_w, _ = mask.shape
    if (mask == 255).all():
        return (mask, (0, int(orig_h), 0, int(orig_w)),
                [0, int(orig_h), 0, int(orig_w)],
                [0, int(orig_h), 0, int(orig_w)])

    hs = np.min(np.where(mask == 0)[0])
    he = np.max(np.where(mask == 0)[0])
    ws = np.min(np.where(mask == 0)[1])
    we = np.max(np.where(mask == 0)[1])
    crop_box = [ws, hs, we, he]

    mask_h = round(int(orig_h / 2) / 4) * 4
    mask_w = round(int(orig_w / 2) / 4) * 4

    if (hs < mask_h) and (he < mask_h) and (ws < mask_w) and (we < mask_w):
        crop_mask = mask[:mask_h, :mask_w, :]
        res_pix = (0, mask_h, 0, mask_w)
    elif (hs < mask_h) and (he < mask_h) and (ws > mask_w) and (we > mask_w):
        crop_mask = mask[:mask_h, orig_w - mask_w:orig_w, :]
        res_pix = (0, mask_h, orig_w - mask_w, int(orig_w))
    elif (hs > mask_h) and (he > mask_h) and (ws < mask_w) and (we < mask_w):
        crop_mask = mask[orig_h - mask_h:orig_h, :mask_w, :]
        res_pix = (orig_h - mask_h, int(orig_h), 0, mask_w)
    elif (hs > mask_h) and (he > mask_h) and (ws > mask_w) and (we > mask_w):
        crop_mask = mask[orig_h - mask_h:orig_h, orig_w - mask_w:orig_w, :]
        res_pix = (orig_h - mask_h, int(orig_h), orig_w - mask_w, int(orig_w))

    elif (hs < mask_h) and (he < mask_h) and (ws < mask_w) and (we > mask_w):
        crop_mask = mask[:mask_h, :, :]
        res_pix = (0, mask_h, 0, int(orig_w))
    elif (hs < mask_h) and (he > mask_h) and (ws < mask_w) and (we < mask_w):
        crop_mask = mask[:, :mask_w, :]
        res_pix = (0, int(orig_h), 0, mask_w)
    elif (hs > mask_h) and (he > mask_h) and (ws < mask_w) and (we > mask_w):
        crop_mask = mask[orig_h - mask_h:orig_h, :, :]
        res_pix = (orig_h - mask_h, int(orig_h), 0, int(orig_w))
    elif (hs < mask_h) and (he > mask_h) and (ws > mask_w) and (we > mask_w):
        crop_mask = mask[:, orig_w - mask_w:orig_w, :]
        res_pix = (0, int(orig_h), orig_w - mask_w, int(orig_w))
    else:
        crop_mask = mask
        res_pix = (0, int(orig_h), 0, int(orig_w))
    a = ws - res_pix[2]
    b = hs - res_pix[0]
    c = we - res_pix[2]
    d = he - res_pix[0]
    return crop_mask, res_pix, crop_box, [a, b, c, d]


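# Long-range references: every `ref_length`-th frame outside the local window.
# E.g. with ref_length=300 and length=300, a window starting at f=40 has
# neighbor_ids 20..60 and gets ref_index [0].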
def get_ref_index(neighbor_ids, length):
    ref_index = []
    for i in range(0, length, ref_length):
        if i not in neighbor_ids:
            ref_index.append(i)
    return ref_index


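# Mask files are named mask_<begin>_<end>.png (e.g. mask_00000_00320.png in
# this commit); the frame range is parsed from the filename and the single
# mask image is replicated across that range.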
def read_mask_oneImage(mpath):
    masks = []
    print('mask_path: {}'.format(mpath))
    start = int(mpath.split('/')[-1].split('mask_')[1].split('_')[0])
    end = int(
        mpath.split('/')[-1].split('mask_')[1].split('_')[1].split('.')[0])
    m = Image.open(mpath)
    m = np.array(m.convert('L'))
    m = np.array(m > 0).astype(np.uint8)
    m = 1 - m
    for i in range(start - 1, end + 1):
        masks.append(Image.fromarray(m * 255))
    return masks


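# Flags frames whose crop differs from the 432x240 size assumed here; only the
# boolean is returned, the local h/w reassignments do not propagate.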
def check_size(h, w):
    is_resize = False
    if h != 240:
        h = 240
        is_resize = True
    if w != 432:
        w = 432
        is_resize = True
    return is_resize


def get_mask_list(mask_path):
    mask_names = os.listdir(mask_path)
    mask_names.sort()

    abs_mask_path = []
    mask_list = []
    begin_list = []
    end_list = []

    for mask_name in mask_names:
        mask_name_tmp = mask_name.split('mask_')[1]
        begin_list.append(int(mask_name_tmp.split('_')[0]))
        end_list.append(int(mask_name_tmp.split('_')[1].split('.')[0]))
        abs_mask_path.append(os.path.join(mask_path, mask_name))
        mask = cv2.imread(os.path.join(mask_path, mask_name))
        mask_list.append(mask)
    return mask_list, begin_list, end_list, abs_mask_path


def inpainting_by_model_balance(model, video_inputPath, mask_path,
                                video_savePath, fps, w_ori, h_ori):

    video_ori = cv2.VideoCapture(video_inputPath)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_save = cv2.VideoWriter(video_savePath, fourcc, fps, (w_ori, h_ori))

    mask_list, begin_list, end_list, abs_mask_path = get_mask_list(mask_path)

    img_npy = []

    for index, mask in enumerate(mask_list):

        masks = read_mask_oneImage(abs_mask_path[index])

        mask, res_pix, crop_for_oriimg, crop_for_inpimg = get_crop_mask_v1(
            mask)
        mask_h, mask_w = mask.shape[0:2]
        is_resize = check_size(mask.shape[0], mask.shape[1])

        begin = begin_list[index]
        end = end_list[index]
        print('begin: {}'.format(begin))
        print('end: {}'.format(end))

        for i in range(begin, end + 1, MAX_frame):
            begin_time = time.time()
            if i + MAX_frame <= end:
                video_length = MAX_frame
            else:
                video_length = end - i + 1

            for frame_count in range(video_length):
                _, frame = video_ori.read()
                img_npy.append(frame)
            frames_temp = []
            for f in img_npy:
                f = Image.fromarray(f)
                i_temp = f.crop(
                    (res_pix[2], res_pix[0], res_pix[3], res_pix[1]))
                a = i_temp.resize((w, h), Image.NEAREST)
                frames_temp.append(a)
            feats_temp = _to_tensors(frames_temp).unsqueeze(0) * 2 - 1
            frames_temp = [np.array(f).astype(np.uint8) for f in frames_temp]
            masks_temp = []
            for m in masks[i - begin:i + video_length - begin]:
                m_temp = m.crop(
                    (res_pix[2], res_pix[0], res_pix[3], res_pix[1]))
                b = m_temp.resize((w, h), Image.NEAREST)
                masks_temp.append(b)
            binary_masks_temp = [
                np.expand_dims((np.array(m) != 0).astype(np.uint8), 2)
                for m in masks_temp
            ]
            masks_temp = _to_tensors(masks_temp).unsqueeze(0)
            feats_temp, masks_temp = feats_temp.cuda(), masks_temp.cuda()
            comp_frames = [None] * video_length
            model.eval()
            with torch.no_grad():
                feats_out = feats_temp * (1 - masks_temp).float()
                feats_out = feats_out.view(video_length, 3, h, w)
                feats_out = model.model.encoder(feats_out)
                _, c, feat_h, feat_w = feats_out.size()
                feats_out = feats_out.view(1, video_length, c, feat_h, feat_w)

            for f in range(0, video_length, neighbor_stride):
                neighbor_ids = [
                    i for i in range(
                        max(0, f - neighbor_stride),
                        min(video_length, f + neighbor_stride + 1))
                ]
                ref_ids = get_ref_index(neighbor_ids, video_length)
                with torch.no_grad():
                    pred_feat = model.model.infer(
                        feats_out[0, neighbor_ids + ref_ids, :, :, :],
                        masks_temp[0, neighbor_ids + ref_ids, :, :, :])
                    pred_img = torch.tanh(
                        model.model.decoder(
                            pred_feat[:len(neighbor_ids), :, :, :])).detach()
                    pred_img = (pred_img + 1) / 2
                    pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
                for j in range(len(neighbor_ids)):
                    idx = neighbor_ids[j]
                    img = np.array(pred_img[j]).astype(
                        np.uint8) * binary_masks_temp[idx] + frames_temp[
                            idx] * (1 - binary_masks_temp[idx])
                    if comp_frames[idx] is None:
                        comp_frames[idx] = img
                    else:
                        comp_frames[idx] = comp_frames[idx].astype(
                            np.float32) * 0.5 + img.astype(np.float32) * 0.5
            print('inpainting time:', time.time() - begin_time)
            for f in range(video_length):
                comp = np.array(comp_frames[f]).astype(
                    np.uint8) * binary_masks_temp[f] + frames_temp[f] * (
                        1 - binary_masks_temp[f])
                if is_resize:
                    comp = cv2.resize(comp, (mask_w, mask_h))
                complete_frame = img_npy[f]
                a1, b1, c1, d1 = crop_for_oriimg
                a2, b2, c2, d2 = crop_for_inpimg
                complete_frame[b1:d1, a1:c1] = comp[b2:d2, a2:c2]
                video_save.write(complete_frame)

            img_npy = []

    video_ori.release()
    video_save.release()
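
For reference, these two helpers are exactly what the pipeline below drives; a minimal standalone sketch (assuming `model` is an already-loaded VideoInpainting instance and using the test assets from this commit; a GPU is required, since the helper moves tensors with .cuda()):

from modelscope.models.cv.video_inpainting import inpainting

video_in = 'data/test/videos/video_inpainting_test.mp4'
decode_error, fps, w_ori, h_ori = inpainting.video_process(video_in)
if decode_error is None:
    inpainting.inpainting_by_model_balance(
        model, video_in, 'data/test/videos/mask_dir', 'out.mp4',
        fps, w_ori, h_ori)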

modelscope/models/cv/video_inpainting/inpainting_model.py (+373, -0)

@@ -0,0 +1,373 @@
""" VideoInpaintingNetwork
Base modules are adapted from https://github.com/researchmm/STTN,
originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab,
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm as _spectral_norm

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


class BaseNetwork(nn.Module):

    def __init__(self):
        super(BaseNetwork, self).__init__()

    def print_network(self):
        if isinstance(self, list):
            self = self[0]
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print(
            'Network [%s] was created. Total number of parameters: %.1f million. '
            'To see the architecture, do print(network).' %
            (type(self).__name__, num_params / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        '''
        initialize network's weights
        init_type: normal | xavier | kaiming | orthogonal
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        '''

        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('InstanceNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    nn.init.constant_(m.weight.data, 1.0)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1
                                           or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    nn.init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':
                    m.reset_parameters()
                else:
                    raise NotImplementedError(
                        'initialization method [%s] is not implemented'
                        % init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)


@MODELS.register_module(
    Tasks.video_inpainting, module_name=Models.video_inpainting)
class VideoInpainting(TorchModel):

    def __init__(self, model_dir, device_id=0, *args, **kwargs):
        super().__init__(
            model_dir=model_dir, device_id=device_id, *args, **kwargs)
        self.model = InpaintGenerator()
        pretrained_params = torch.load('{}/{}'.format(
            model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
        self.model.load_state_dict(pretrained_params['netG'])
        self.model.eval()
        self.device_id = device_id
        if self.device_id >= 0 and torch.cuda.is_available():
            self.model.to('cuda:{}'.format(self.device_id))
            logger.info('Use GPU: {}'.format(self.device_id))
        else:
            self.device_id = -1
            logger.info('Use CPU for inference')


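# STTN-style generator: a conv encoder downsamples frames 4x, six stacked
# TransformerBlocks apply joint spatial-temporal attention over multi-scale
# patches, and a decoder with two bilinear-upsampling stages restores the
# input resolution through a tanh output.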
class InpaintGenerator(BaseNetwork):

    def __init__(self, init_weights=True):
        super(InpaintGenerator, self).__init__()
        channel = 256
        stack_num = 6
        patchsize = [(48, 24), (16, 8), (8, 4), (4, 2)]
        blocks = []
        for _ in range(stack_num):
            blocks.append(TransformerBlock(patchsize, hidden=channel))
        self.transformer = nn.Sequential(*blocks)

        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, channel, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.decoder = nn.Sequential(
            deconv(channel, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            deconv(64, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))

        if init_weights:
            self.init_weights()

    def forward(self, masked_frames, masks):
        b, t, c, h, w = masked_frames.size()
        masks = masks.view(b * t, 1, h, w)
        enc_feat = self.encoder(masked_frames.view(b * t, c, h, w))
        _, c, h, w = enc_feat.size()
        masks = F.interpolate(masks, scale_factor=1.0 / 4)
        enc_feat = self.transformer({
            'x': enc_feat,
            'm': masks,
            'b': b,
            'c': c
        })['x']
        output = self.decoder(enc_feat)
        output = torch.tanh(output)
        return output

    def infer(self, feat, masks):
        t, c, h, w = masks.size()
        masks = masks.view(t, c, h, w)
        masks = F.interpolate(masks, scale_factor=1.0 / 4)
        t, c, _, _ = feat.size()
        enc_feat = self.transformer({
            'x': feat,
            'm': masks,
            'b': 1,
            'c': c
        })['x']
        return enc_feat


class deconv(nn.Module):

    def __init__(self,
                 input_channel,
                 output_channel,
                 kernel_size=3,
                 padding=0):
        super().__init__()
        self.conv = nn.Conv2d(
            input_channel,
            output_channel,
            kernel_size=kernel_size,
            stride=1,
            padding=padding)

    def forward(self, x):
        x = F.interpolate(
            x, scale_factor=2, mode='bilinear', align_corners=True)
        x = self.conv(x)
        return x


class Attention(nn.Module):
    """
    Compute 'Scaled Dot Product Attention'
    """

    def forward(self, query, key, value, m):
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
            query.size(-1))
        scores = scores.masked_fill(m, -1e9)
        p_attn = F.softmax(scores, dim=-1)
        p_val = torch.matmul(p_attn, value)
        return p_val, p_attn


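# Patch-based attention at the four scales in `patchsize`, one channel chunk
# per scale. At the 48x24 feature resolution produced from 192x96 inputs, the
# (48, 24), (16, 8), (8, 4) and (4, 2) patch sizes tile each frame into 1x1,
# 3x3, 6x6 and 12x12 patch grids respectively.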
class MultiHeadedAttention(nn.Module):
    """
    Take in model size and number of heads.
    """

    def __init__(self, patchsize, d_model):
        super().__init__()
        self.patchsize = patchsize
        self.query_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0)
        self.value_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0)
        self.key_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0)
        self.output_linear = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True))
        self.attention = Attention()

    def forward(self, x, m, b, c):
        bt, _, h, w = x.size()
        t = bt // b
        d_k = c // len(self.patchsize)
        output = []
        _query = self.query_embedding(x)
        _key = self.key_embedding(x)
        _value = self.value_embedding(x)
        for (width, height), query, key, value in zip(
                self.patchsize,
                torch.chunk(_query, len(self.patchsize), dim=1),
                torch.chunk(_key, len(self.patchsize), dim=1),
                torch.chunk(_value, len(self.patchsize), dim=1)):
            out_w, out_h = w // width, h // height
            mm = m.view(b, t, 1, out_h, height, out_w, width)
            mm = mm.permute(0, 1, 3, 5, 2, 4,
                            6).contiguous().view(b, t * out_h * out_w,
                                                 height * width)
            mm = (mm.mean(-1) > 0.5).unsqueeze(1).repeat(
                1, t * out_h * out_w, 1)
            query = query.view(b, t, d_k, out_h, height, out_w, width)
            query = query.permute(0, 1, 3, 5, 2, 4,
                                  6).contiguous().view(b, t * out_h * out_w,
                                                       d_k * height * width)
            key = key.view(b, t, d_k, out_h, height, out_w, width)
            key = key.permute(0, 1, 3, 5, 2, 4,
                              6).contiguous().view(b, t * out_h * out_w,
                                                   d_k * height * width)
            value = value.view(b, t, d_k, out_h, height, out_w, width)
            value = value.permute(0, 1, 3, 5, 2, 4,
                                  6).contiguous().view(b, t * out_h * out_w,
                                                       d_k * height * width)
            y, _ = self.attention(query, key, value, mm)
            y = y.view(b, t, out_h, out_w, d_k, height, width)
            y = y.permute(0, 1, 4, 2, 5, 3, 6).contiguous().view(bt, d_k, h, w)
            output.append(y)
        output = torch.cat(output, 1)
        x = self.output_linear(output)
        return x


class FeedForward(nn.Module):

    def __init__(self, d_model):
        super(FeedForward, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=2, dilation=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        x = self.conv(x)
        return x


class TransformerBlock(nn.Module):
    """
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """

    def __init__(self, patchsize, hidden=128):
        super().__init__()
        self.attention = MultiHeadedAttention(patchsize, d_model=hidden)
        self.feed_forward = FeedForward(hidden)

    def forward(self, x):
        x, m, b, c = x['x'], x['m'], x['b'], x['c']
        x = x + self.attention(x, m, b, c)
        x = x + self.feed_forward(x)
        return {'x': x, 'm': m, 'b': b, 'c': c}


class Discriminator(BaseNetwork):

    def __init__(self,
                 in_channels=3,
                 use_sigmoid=False,
                 use_spectral_norm=True,
                 init_weights=True):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid
        nf = 64

        self.conv = nn.Sequential(
            spectral_norm(
                nn.Conv3d(
                    in_channels=in_channels,
                    out_channels=nf * 1,
                    kernel_size=(3, 5, 5),
                    stride=(1, 2, 2),
                    padding=1,
                    bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(
                nn.Conv3d(
                    nf * 1,
                    nf * 2,
                    kernel_size=(3, 5, 5),
                    stride=(1, 2, 2),
                    padding=(1, 2, 2),
                    bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(
                nn.Conv3d(
                    nf * 2,
                    nf * 4,
                    kernel_size=(3, 5, 5),
                    stride=(1, 2, 2),
                    padding=(1, 2, 2),
                    bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(
                nn.Conv3d(
                    nf * 4,
                    nf * 4,
                    kernel_size=(3, 5, 5),
                    stride=(1, 2, 2),
                    padding=(1, 2, 2),
                    bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(
                nn.Conv3d(
                    nf * 4,
                    nf * 4,
                    kernel_size=(3, 5, 5),
                    stride=(1, 2, 2),
                    padding=(1, 2, 2),
                    bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(
                nf * 4,
                nf * 4,
                kernel_size=(3, 5, 5),
                stride=(1, 2, 2),
                padding=(1, 2, 2)))

        if init_weights:
            self.init_weights()

    def forward(self, xs):
        xs_t = torch.transpose(xs, 0, 1)
        xs_t = xs_t.unsqueeze(0)
        feat = self.conv(xs_t)
        if self.use_sigmoid:
            feat = torch.sigmoid(feat)
        out = torch.transpose(feat, 1, 2)
        return out


def spectral_norm(module, mode=True):
    if mode:
        return _spectral_norm(module)
    return module
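
A quick shape check for the generator (a sketch under the 192x96 working resolution that inpainting.py uses; runs on CPU with random weights):

import torch

net = InpaintGenerator(init_weights=False)
frames = torch.zeros(1, 2, 3, 96, 192)  # (batch, time, channels, H, W)
masks = torch.ones(1, 2, 1, 96, 192)    # 1 marks pixels to inpaint
with torch.no_grad():
    out = net(frames * (1 - masks), masks)
print(out.shape)  # torch.Size([2, 3, 96, 192]), i.e. (batch*time, C, H, W)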

modelscope/outputs.py (+5, -0)

@@ -610,4 +610,9 @@ TASK_OUTPUTS = {
    # {
    #     "img_embedding": np.array with shape [1, D],
    # }
    Tasks.image_reid_person: [OutputKeys.IMG_EMBEDDING],

    # {
    #     'output': 'Done' or 'decode_error'
    # }
    Tasks.video_inpainting: [OutputKeys.OUTPUT]
}

modelscope/pipelines/builder.py (+2, -0)

@@ -168,6 +168,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
        'damo/cv_resnet50-bert_video-scene-segmentation_movienet'),
    Tasks.shop_segmentation: (Pipelines.shop_segmentation,
                              'damo/cv_vitb16_segmentation_shop-seg'),
    Tasks.video_inpainting: (Pipelines.video_inpainting,
                             'damo/cv_video-inpainting'),
}




modelscope/pipelines/cv/video_inpainting_pipeline.py (+47, -0)

@@ -0,0 +1,47 @@
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.models.cv.video_inpainting import inpainting
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.video_inpainting, module_name=Pipelines.video_inpainting)
class VideoInpaintingPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        Use `model` to create a video inpainting pipeline for prediction.
        Args:
            model: model id on modelscope hub.
        """

        super().__init__(model=model, **kwargs)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        decode_error, fps, w, h = inpainting.video_process(
            input['video_input_path'])

        if decode_error is not None:
            return {OutputKeys.OUTPUT: 'decode_error'}

        inpainting.inpainting_by_model_balance(self.model,
                                               input['video_input_path'],
                                               input['mask_path'],
                                               input['video_output_path'],
                                               fps, w, h)

        return {OutputKeys.OUTPUT: 'Done'}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
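
Usage matches the test added below: the pipeline takes a dict naming the input video, the mask directory and the output path, and reports 'Done' or 'decode_error':

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

inpainter = pipeline(Tasks.video_inpainting, model='damo/cv_video-inpainting')
result = inpainter({
    'video_input_path': 'data/test/videos/video_inpainting_test.mp4',
    'mask_path': 'data/test/videos/mask_dir',
    'video_output_path': 'out.mp4'
})
print(result[OutputKeys.OUTPUT])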

modelscope/utils/constant.py (+3, -0)

@@ -70,6 +70,9 @@ class CVTasks(object):
    crowd_counting = 'crowd-counting'
    movie_scene_segmentation = 'movie-scene-segmentation'

    # video editing
    video_inpainting = 'video-inpainting'

    # reid and tracking
    video_single_object_tracking = 'video-single-object-tracking'
    video_summarization = 'video-summarization'


tests/pipelines/test_person_image_cartoon.py (+0, -1)

@@ -1,5 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
-import os
import os.path as osp
import unittest



tests/pipelines/test_video_inpainting.py (+39, -0)

@@ -0,0 +1,39 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class VideoInpaintingTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model = 'damo/cv_video-inpainting'
        self.mask_dir = 'data/test/videos/mask_dir'
        self.video_in = 'data/test/videos/video_inpainting_test.mp4'
        self.video_out = 'out.mp4'
        self.input = {
            'video_input_path': self.video_in,
            'video_output_path': self.video_out,
            'mask_path': self.mask_dir
        }

    def pipeline_inference(self, pipeline: Pipeline, input: str):
        result = pipeline(input)
        print(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        video_inpainting = pipeline(Tasks.video_inpainting, model=self.model)
        self.pipeline_inference(video_inpainting, self.input)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        video_inpainting = pipeline(Tasks.video_inpainting)
        self.pipeline_inference(video_inpainting, self.input)


if __name__ == '__main__':
    unittest.main()
