From 7b23c417484f038cd0c7fcd76f2ca95e677cac94 Mon Sep 17 00:00:00 2001
From: "tingwei.gtw"
Date: Wed, 7 Sep 2022 21:06:25 +0800
Subject: [PATCH] [to #42322933] Add video-inpainting files
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Code review for the video editing (video inpainting) pipeline.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10026166
---
 .../test/videos/mask_dir/mask_00000_00320.png | 3 +
 .../test/videos/mask_dir/mask_00321_00633.png | 3 +
 data/test/videos/video_inpainting_test.mp4 | 3 +
 modelscope/metainfo.py | 2 +
 .../models/cv/video_inpainting/__init__.py | 20 +
 .../models/cv/video_inpainting/inpainting.py | 298 ++++++++++++++
 .../cv/video_inpainting/inpainting_model.py | 373 ++++++++++++++++++
 modelscope/outputs.py | 5 +
 modelscope/pipelines/builder.py | 2 +
 .../pipelines/cv/video_inpainting_pipeline.py | 47 +++
 modelscope/utils/constant.py | 3 +
 tests/pipelines/test_person_image_cartoon.py | 1 -
 tests/pipelines/test_video_inpainting.py | 39 ++
 13 files changed, 798 insertions(+), 1 deletion(-)
 create mode 100644 data/test/videos/mask_dir/mask_00000_00320.png
 create mode 100644 data/test/videos/mask_dir/mask_00321_00633.png
 create mode 100644 data/test/videos/video_inpainting_test.mp4
 create mode 100644 modelscope/models/cv/video_inpainting/__init__.py
 create mode 100644 modelscope/models/cv/video_inpainting/inpainting.py
 create mode 100644 modelscope/models/cv/video_inpainting/inpainting_model.py
 create mode 100644 modelscope/pipelines/cv/video_inpainting_pipeline.py
 create mode 100644 tests/pipelines/test_video_inpainting.py

diff --git a/data/test/videos/mask_dir/mask_00000_00320.png b/data/test/videos/mask_dir/mask_00000_00320.png
new file mode 100644
index 00000000..2eae71a1
--- /dev/null
+++ b/data/test/videos/mask_dir/mask_00000_00320.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b158f6029d9763d7f84042f7c5835f398c688fdbb6b3f4fe6431101d4118c66c
+size 2766
diff --git a/data/test/videos/mask_dir/mask_00321_00633.png b/data/test/videos/mask_dir/mask_00321_00633.png
new file mode 100644
index 00000000..89633eb6
--- /dev/null
+++ b/data/test/videos/mask_dir/mask_00321_00633.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0dcf46b93077e2229ab69cd6ddb80e2689546c575ee538bb2033fee1124ef3e3
+size 2761
diff --git a/data/test/videos/video_inpainting_test.mp4 b/data/test/videos/video_inpainting_test.mp4
new file mode 100644
index 00000000..61f96fac
--- /dev/null
+++ b/data/test/videos/video_inpainting_test.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9c9870df5a86acaaec67063183dace795479cd0f05296f13058995f475149c56
+size 2957783
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index f904b5df..1bb2c389 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -38,6 +38,7 @@ class Models(object):
     mogface = 'mogface'
     mtcnn = 'mtcnn'
     ulfd = 'ulfd'
+    video_inpainting = 'video-inpainting'
 
     # EasyCV models
     yolox = 'YOLOX'
@@ -169,6 +170,7 @@ class Pipelines(object):
     text_driven_segmentation = 'text-driven-segmentation'
     movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation'
     shop_segmentation = 'shop-segmentation'
+    video_inpainting = 'video-inpainting'
 
     # nlp tasks
     sentence_similarity = 'sentence-similarity'
diff --git a/modelscope/models/cv/video_inpainting/__init__.py b/modelscope/models/cv/video_inpainting/__init__.py
new file mode 100644
index 00000000..fd93fe3c
--- /dev/null
+++ b/modelscope/models/cv/video_inpainting/__init__.py
@@ -0,0 +1,20 @@
+# 
copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .inpainting_model import VideoInpainting + +else: + _import_structure = {'inpainting_model': ['VideoInpainting']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/video_inpainting/inpainting.py b/modelscope/models/cv/video_inpainting/inpainting.py new file mode 100644 index 00000000..9632e01c --- /dev/null +++ b/modelscope/models/cv/video_inpainting/inpainting.py @@ -0,0 +1,298 @@ +""" VideoInpaintingProcess +Base modules are adapted from https://github.com/researchmm/STTN, +originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, +""" + +import os +import time + +import cv2 +import numpy as np +import torch +from PIL import Image +from torchvision import transforms + +torch.backends.cudnn.enabled = False + +w, h = 192, 96 +ref_length = 300 +neighbor_stride = 20 +default_fps = 24 +MAX_frame = 300 + + +def video_process(video_input_path): + video_input = cv2.VideoCapture(video_input_path) + success, frame = video_input.read() + if success is False: + decode_error = 'decode_error' + w, h, fps = 0, 0, 0 + else: + decode_error = None + h, w = frame.shape[0:2] + fps = video_input.get(cv2.CAP_PROP_FPS) + video_input.release() + + return decode_error, fps, w, h + + +class Stack(object): + + def __init__(self, roll=False): + self.roll = roll + + def __call__(self, img_group): + mode = img_group[0].mode + if mode == '1': + img_group = [img.convert('L') for img in img_group] + mode = 'L' + if mode == 'L': + return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2) + elif mode == 'RGB': + if self.roll: + return np.stack([np.array(x)[:, :, ::-1] for x in img_group], + axis=2) + else: + return np.stack(img_group, axis=2) + else: + raise NotImplementedError(f'Image mode {mode}') + + +class ToTorchFormatTensor(object): + """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] + to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ + + def __init__(self, div=True): + self.div = div + + def __call__(self, pic): + if isinstance(pic, np.ndarray): + img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous() + else: + img = torch.ByteTensor( + torch.ByteStorage.from_buffer(pic.tobytes())) + img = img.view(pic.size[1], pic.size[0], len(pic.mode)) + img = img.transpose(0, 1).transpose(0, 2).contiguous() + img = img.float().div(255) if self.div else img.float() + return img + + +_to_tensors = transforms.Compose([Stack(), ToTorchFormatTensor()]) + + +def get_crop_mask_v1(mask): + orig_h, orig_w, _ = mask.shape + if (mask == 255).all(): + return mask, (0, int(orig_h), 0, + int(orig_w)), [0, int(orig_h), 0, + int(orig_w) + ], [0, int(orig_h), 0, + int(orig_w)] + + hs = np.min(np.where(mask == 0)[0]) + he = np.max(np.where(mask == 0)[0]) + ws = np.min(np.where(mask == 0)[1]) + we = np.max(np.where(mask == 0)[1]) + crop_box = [ws, hs, we, he] + + mask_h = round(int(orig_h / 2) / 4) * 4 + mask_w = round(int(orig_w / 2) / 4) * 4 + + if (hs < mask_h) and (he < mask_h) and (ws < mask_w) and (we < mask_w): + crop_mask = mask[:mask_h, :mask_w, :] + res_pix = (0, mask_h, 0, mask_w) + elif (hs < mask_h) and (he < mask_h) and (ws > mask_w) and (we > mask_w): + crop_mask = mask[:mask_h, orig_w - mask_w:orig_w, :] + res_pix = (0, mask_h, orig_w - 
mask_w, int(orig_w)) + elif (hs > mask_h) and (he > mask_h) and (ws < mask_w) and (we < mask_w): + crop_mask = mask[orig_h - mask_h:orig_h, :mask_w, :] + res_pix = (orig_h - mask_h, int(orig_h), 0, mask_w) + elif (hs > mask_h) and (he > mask_h) and (ws > mask_w) and (we > mask_w): + crop_mask = mask[orig_h - mask_h:orig_h, orig_w - mask_w:orig_w, :] + res_pix = (orig_h - mask_h, int(orig_h), orig_w - mask_w, int(orig_w)) + + elif (hs < mask_h) and (he < mask_h) and (ws < mask_w) and (we > mask_w): + crop_mask = mask[:mask_h, :, :] + res_pix = (0, mask_h, 0, int(orig_w)) + elif (hs < mask_h) and (he > mask_h) and (ws < mask_w) and (we < mask_w): + crop_mask = mask[:, :mask_w, :] + res_pix = (0, int(orig_h), 0, mask_w) + elif (hs > mask_h) and (he > mask_h) and (ws < mask_w) and (we > mask_w): + crop_mask = mask[orig_h - mask_h:orig_h, :, :] + res_pix = (orig_h - mask_h, int(orig_h), 0, int(orig_w)) + elif (hs < mask_h) and (he > mask_h) and (ws > mask_w) and (we > mask_w): + crop_mask = mask[:, orig_w - mask_w:orig_w, :] + res_pix = (0, int(orig_h), orig_w - mask_w, int(orig_w)) + else: + crop_mask = mask + res_pix = (0, int(orig_h), 0, int(orig_w)) + a = ws - res_pix[2] + b = hs - res_pix[0] + c = we - res_pix[2] + d = he - res_pix[0] + return crop_mask, res_pix, crop_box, [a, b, c, d] + + +def get_ref_index(neighbor_ids, length): + ref_index = [] + for i in range(0, length, ref_length): + if i not in neighbor_ids: + ref_index.append(i) + return ref_index + + +def read_mask_oneImage(mpath): + masks = [] + print('mask_path: {}'.format(mpath)) + start = int(mpath.split('/')[-1].split('mask_')[1].split('_')[0]) + end = int( + mpath.split('/')[-1].split('mask_')[1].split('_')[1].split('.')[0]) + m = Image.open(mpath) + m = np.array(m.convert('L')) + m = np.array(m > 0).astype(np.uint8) + m = 1 - m + for i in range(start - 1, end + 1): + masks.append(Image.fromarray(m * 255)) + return masks + + +def check_size(h, w): + is_resize = False + if h != 240: + h = 240 + is_resize = True + if w != 432: + w = 432 + is_resize = True + return is_resize + + +def get_mask_list(mask_path): + mask_names = os.listdir(mask_path) + mask_names.sort() + + abs_mask_path = [] + mask_list = [] + begin_list = [] + end_list = [] + + for mask_name in mask_names: + mask_name_tmp = mask_name.split('mask_')[1] + begin_list.append(int(mask_name_tmp.split('_')[0])) + end_list.append(int(mask_name_tmp.split('_')[1].split('.')[0])) + abs_mask_path.append(os.path.join(mask_path, mask_name)) + mask = cv2.imread(os.path.join(mask_path, mask_name)) + mask_list.append(mask) + return mask_list, begin_list, end_list, abs_mask_path + + +def inpainting_by_model_balance(model, video_inputPath, mask_path, + video_savePath, fps, w_ori, h_ori): + + video_ori = cv2.VideoCapture(video_inputPath) + + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + video_save = cv2.VideoWriter(video_savePath, fourcc, fps, (w_ori, h_ori)) + + mask_list, begin_list, end_list, abs_mask_path = get_mask_list(mask_path) + + img_npy = [] + + for index, mask in enumerate(mask_list): + + masks = read_mask_oneImage(abs_mask_path[index]) + + mask, res_pix, crop_for_oriimg, crop_for_inpimg = get_crop_mask_v1( + mask) + mask_h, mask_w = mask.shape[0:2] + is_resize = check_size(mask.shape[0], mask.shape[1]) + + begin = begin_list[index] + end = end_list[index] + print('begin: {}'.format(begin)) + print('end: {}'.format(end)) + + for i in range(begin, end + 1, MAX_frame): + begin_time = time.time() + if i + MAX_frame <= end: + video_length = MAX_frame + else: + video_length = end 
- i + 1 + + for frame_count in range(video_length): + _, frame = video_ori.read() + img_npy.append(frame) + frames_temp = [] + for f in img_npy: + f = Image.fromarray(f) + i_temp = f.crop( + (res_pix[2], res_pix[0], res_pix[3], res_pix[1])) + a = i_temp.resize((w, h), Image.NEAREST) + frames_temp.append(a) + feats_temp = _to_tensors(frames_temp).unsqueeze(0) * 2 - 1 + frames_temp = [np.array(f).astype(np.uint8) for f in frames_temp] + masks_temp = [] + for m in masks[i - begin:i + video_length - begin]: + + m_temp = m.crop( + (res_pix[2], res_pix[0], res_pix[3], res_pix[1])) + b = m_temp.resize((w, h), Image.NEAREST) + masks_temp.append(b) + binary_masks_temp = [ + np.expand_dims((np.array(m) != 0).astype(np.uint8), 2) + for m in masks_temp + ] + masks_temp = _to_tensors(masks_temp).unsqueeze(0) + feats_temp, masks_temp = feats_temp.cuda(), masks_temp.cuda() + comp_frames = [None] * video_length + model.eval() + with torch.no_grad(): + feats_out = feats_temp * (1 - masks_temp).float() + feats_out = feats_out.view(video_length, 3, h, w) + feats_out = model.model.encoder(feats_out) + _, c, feat_h, feat_w = feats_out.size() + feats_out = feats_out.view(1, video_length, c, feat_h, feat_w) + + for f in range(0, video_length, neighbor_stride): + neighbor_ids = [ + i for i in range( + max(0, f - neighbor_stride), + min(video_length, f + neighbor_stride + 1)) + ] + ref_ids = get_ref_index(neighbor_ids, video_length) + with torch.no_grad(): + pred_feat = model.model.infer( + feats_out[0, neighbor_ids + ref_ids, :, :, :], + masks_temp[0, neighbor_ids + ref_ids, :, :, :]) + pred_img = torch.tanh( + model.model.decoder( + pred_feat[:len(neighbor_ids), :, :, :])).detach() + pred_img = (pred_img + 1) / 2 + pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255 + for j in range(len(neighbor_ids)): + idx = neighbor_ids[j] + img = np.array(pred_img[j]).astype( + np.uint8) * binary_masks_temp[idx] + frames_temp[ + idx] * (1 - binary_masks_temp[idx]) + if comp_frames[idx] is None: + comp_frames[idx] = img + else: + comp_frames[idx] = comp_frames[idx].astype( + np.float32) * 0.5 + img.astype( + np.float32) * 0.5 + print('inpainting time:', time.time() - begin_time) + for f in range(video_length): + comp = np.array(comp_frames[f]).astype( + np.uint8) * binary_masks_temp[f] + frames_temp[f] * ( + 1 - binary_masks_temp[f]) + if is_resize: + comp = cv2.resize(comp, (mask_w, mask_h)) + complete_frame = img_npy[f] + a1, b1, c1, d1 = crop_for_oriimg + a2, b2, c2, d2 = crop_for_inpimg + complete_frame[b1:d1, a1:c1] = comp[b2:d2, a2:c2] + video_save.write(complete_frame) + + img_npy = [] + + video_ori.release() diff --git a/modelscope/models/cv/video_inpainting/inpainting_model.py b/modelscope/models/cv/video_inpainting/inpainting_model.py new file mode 100644 index 00000000..a791b0ab --- /dev/null +++ b/modelscope/models/cv/video_inpainting/inpainting_model.py @@ -0,0 +1,373 @@ +""" VideoInpaintingNetwork +Base modules are adapted from https://github.com/researchmm/STTN, +originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, +""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +class BaseNetwork(nn.Module): + + def __init__(self): + super(BaseNetwork, self).__init__() + + def print_network(self): + 
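+        # Print the network class name and its total parameter count (in millions).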
if isinstance(self, list): + self = self[0] + num_params = 0 + for param in self.parameters(): + num_params += param.numel() + print( + 'Network [%s] was created. Total number of parameters: %.1f million. ' + 'To see the architecture, do print(network).' % + (type(self).__name__, num_params / 1000000)) + + def init_weights(self, init_type='normal', gain=0.02): + ''' + initialize network's weights + init_type: normal | xavier | kaiming | orthogonal + https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 + ''' + + def init_func(m): + classname = m.__class__.__name__ + if classname.find('InstanceNorm2d') != -1: + if hasattr(m, 'weight') and m.weight is not None: + nn.init.constant_(m.weight.data, 1.0) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + elif hasattr(m, 'weight') and (classname.find('Conv') != -1 + or classname.find('Linear') != -1): + if init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'xavier_uniform': + nn.init.xavier_uniform_(m.weight.data, gain=1.0) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=gain) + elif init_type == 'none': + m.reset_parameters() + else: + raise NotImplementedError( + 'initialization method [%s] is not implemented' + % init_type) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights(init_type, gain) + + +@MODELS.register_module( + Tasks.video_inpainting, module_name=Models.video_inpainting) +class VideoInpainting(TorchModel): + + def __init__(self, model_dir, device_id=0, *args, **kwargs): + super().__init__( + model_dir=model_dir, device_id=device_id, *args, **kwargs) + self.model = InpaintGenerator() + pretrained_params = torch.load('{}/{}'.format( + model_dir, ModelFile.TORCH_MODEL_BIN_FILE)) + self.model.load_state_dict(pretrained_params['netG']) + self.model.eval() + self.device_id = device_id + if self.device_id >= 0 and torch.cuda.is_available(): + self.model.to('cuda:{}'.format(self.device_id)) + logger.info('Use GPU: {}'.format(self.device_id)) + else: + self.device_id = -1 + logger.info('Use CPU for inference') + + +class InpaintGenerator(BaseNetwork): + + def __init__(self, init_weights=True): + super(InpaintGenerator, self).__init__() + channel = 256 + stack_num = 6 + patchsize = [(48, 24), (16, 8), (8, 4), (4, 2)] + blocks = [] + for _ in range(stack_num): + blocks.append(TransformerBlock(patchsize, hidden=channel)) + self.transformer = nn.Sequential(*blocks) + + self.encoder = nn.Sequential( + nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, channel, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + ) + + self.decoder = nn.Sequential( + deconv(channel, 128, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + deconv(64, 64, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + 
nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)) + + if init_weights: + self.init_weights() + + def forward(self, masked_frames, masks): + b, t, c, h, w = masked_frames.size() + masks = masks.view(b * t, 1, h, w) + enc_feat = self.encoder(masked_frames.view(b * t, c, h, w)) + _, c, h, w = enc_feat.size() + masks = F.interpolate(masks, scale_factor=1.0 / 4) + enc_feat = self.transformer({ + 'x': enc_feat, + 'm': masks, + 'b': b, + 'c': c + })['x'] + output = self.decoder(enc_feat) + output = torch.tanh(output) + return output + + def infer(self, feat, masks): + t, c, h, w = masks.size() + masks = masks.view(t, c, h, w) + masks = F.interpolate(masks, scale_factor=1.0 / 4) + t, c, _, _ = feat.size() + enc_feat = self.transformer({ + 'x': feat, + 'm': masks, + 'b': 1, + 'c': c + })['x'] + return enc_feat + + +class deconv(nn.Module): + + def __init__(self, + input_channel, + output_channel, + kernel_size=3, + padding=0): + super().__init__() + self.conv = nn.Conv2d( + input_channel, + output_channel, + kernel_size=kernel_size, + stride=1, + padding=padding) + + def forward(self, x): + x = F.interpolate( + x, scale_factor=2, mode='bilinear', align_corners=True) + x = self.conv(x) + return x + + +class Attention(nn.Module): + """ + Compute 'Scaled Dot Product Attention + """ + + def forward(self, query, key, value, m): + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt( + query.size(-1)) + scores.masked_fill(m, -1e9) + p_attn = F.softmax(scores, dim=-1) + p_val = torch.matmul(p_attn, value) + return p_val, p_attn + + +class MultiHeadedAttention(nn.Module): + """ + Take in model size and number of heads. + """ + + def __init__(self, patchsize, d_model): + super().__init__() + self.patchsize = patchsize + self.query_embedding = nn.Conv2d( + d_model, d_model, kernel_size=1, padding=0) + self.value_embedding = nn.Conv2d( + d_model, d_model, kernel_size=1, padding=0) + self.key_embedding = nn.Conv2d( + d_model, d_model, kernel_size=1, padding=0) + self.output_linear = nn.Sequential( + nn.Conv2d(d_model, d_model, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True)) + self.attention = Attention() + + def forward(self, x, m, b, c): + bt, _, h, w = x.size() + t = bt // b + d_k = c // len(self.patchsize) + output = [] + _query = self.query_embedding(x) + _key = self.key_embedding(x) + _value = self.value_embedding(x) + for (width, height), query, key, value in zip( + self.patchsize, + torch.chunk(_query, len(self.patchsize), dim=1), + torch.chunk(_key, len(self.patchsize), dim=1), + torch.chunk(_value, len(self.patchsize), dim=1)): + out_w, out_h = w // width, h // height + mm = m.view(b, t, 1, out_h, height, out_w, width) + mm = mm.permute(0, 1, 3, 5, 2, 4, + 6).contiguous().view(b, t * out_h * out_w, + height * width) + mm = (mm.mean(-1) > 0.5).unsqueeze(1).repeat( + 1, t * out_h * out_w, 1) + query = query.view(b, t, d_k, out_h, height, out_w, width) + query = query.permute(0, 1, 3, 5, 2, 4, + 6).contiguous().view(b, t * out_h * out_w, + d_k * height * width) + key = key.view(b, t, d_k, out_h, height, out_w, width) + key = key.permute(0, 1, 3, 5, 2, 4, + 6).contiguous().view(b, t * out_h * out_w, + d_k * height * width) + value = value.view(b, t, d_k, out_h, height, out_w, width) + value = value.permute(0, 1, 3, 5, 2, 4, + 6).contiguous().view(b, t * out_h * out_w, + d_k * height * width) + y, _ = self.attention(query, key, value, mm) + y = y.view(b, t, out_h, out_w, d_k, height, width) + y = y.permute(0, 1, 4, 2, 5, 3, 6).contiguous().view(bt, d_k, h, w) + output.append(y) 
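+        # Concatenate the per-patch-size attention outputs along the channel dimension and fuse them with the 3x3 output_linear conv.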
+ output = torch.cat(output, 1) + x = self.output_linear(output) + return x + + +class FeedForward(nn.Module): + + def __init__(self, d_model): + super(FeedForward, self).__init__() + self.conv = nn.Sequential( + nn.Conv2d(d_model, d_model, kernel_size=3, padding=2, dilation=2), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(d_model, d_model, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True)) + + def forward(self, x): + x = self.conv(x) + return x + + +class TransformerBlock(nn.Module): + """ + Transformer = MultiHead_Attention + Feed_Forward with sublayer connection + """ + + def __init__(self, patchsize, hidden=128): # hidden=128 + super().__init__() + self.attention = MultiHeadedAttention(patchsize, d_model=hidden) + self.feed_forward = FeedForward(hidden) + + def forward(self, x): + x, m, b, c = x['x'], x['m'], x['b'], x['c'] + x = x + self.attention(x, m, b, c) + x = x + self.feed_forward(x) + return {'x': x, 'm': m, 'b': b, 'c': c} + + +class Discriminator(BaseNetwork): + + def __init__(self, + in_channels=3, + use_sigmoid=False, + use_spectral_norm=True, + init_weights=True): + super(Discriminator, self).__init__() + self.use_sigmoid = use_sigmoid + nf = 64 + + self.conv = nn.Sequential( + spectral_norm( + nn.Conv3d( + in_channels=in_channels, + out_channels=nf * 1, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=1, + bias=not use_spectral_norm), use_spectral_norm), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d( + nf * 1, + nf * 2, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d( + nf * 2, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d( + nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d( + nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv3d( + nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2))) + + if init_weights: + self.init_weights() + + def forward(self, xs): + xs_t = torch.transpose(xs, 0, 1) + xs_t = xs_t.unsqueeze(0) + feat = self.conv(xs_t) + if self.use_sigmoid: + feat = torch.sigmoid(feat) + out = torch.transpose(feat, 1, 2) + return out + + +def spectral_norm(module, mode=True): + if mode: + return _spectral_norm(module) + return module diff --git a/modelscope/outputs.py b/modelscope/outputs.py index 6c7500bb..37ab3481 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -610,4 +610,9 @@ TASK_OUTPUTS = { # "img_embedding": np.array with shape [1, D], # } Tasks.image_reid_person: [OutputKeys.IMG_EMBEDDING], + + # { + # 'output': ['Done' / 'Decode_Error'] + # } + Tasks.video_inpainting: [OutputKeys.OUTPUT] } diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index fa79ca11..a1f093a3 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -168,6 +168,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'), Tasks.shop_segmentation: (Pipelines.shop_segmentation, 'damo/cv_vitb16_segmentation_shop-seg'), + Tasks.video_inpainting: 
(Pipelines.video_inpainting, + 'damo/cv_video-inpainting'), } diff --git a/modelscope/pipelines/cv/video_inpainting_pipeline.py b/modelscope/pipelines/cv/video_inpainting_pipeline.py new file mode 100644 index 00000000..15444e05 --- /dev/null +++ b/modelscope/pipelines/cv/video_inpainting_pipeline.py @@ -0,0 +1,47 @@ +from typing import Any, Dict + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.video_inpainting import inpainting +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.video_inpainting, module_name=Pipelines.video_inpainting) +class VideoInpaintingPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` to create video inpainting pipeline for prediction + Args: + model: model id on modelscope hub. + """ + + super().__init__(model=model, **kwargs) + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + decode_error, fps, w, h = inpainting.video_process( + input['video_input_path']) + + if decode_error is not None: + return {OutputKeys.OUTPUT: 'decode_error'} + + inpainting.inpainting_by_model_balance(self.model, + input['video_input_path'], + input['mask_path'], + input['video_output_path'], fps, + w, h) + + return {OutputKeys.OUTPUT: 'Done'} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 47d38dd7..8fb00ed6 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -70,6 +70,9 @@ class CVTasks(object): crowd_counting = 'crowd-counting' movie_scene_segmentation = 'movie-scene-segmentation' + # video editing + video_inpainting = 'video-inpainting' + # reid and tracking video_single_object_tracking = 'video-single-object-tracking' video_summarization = 'video-summarization' diff --git a/tests/pipelines/test_person_image_cartoon.py b/tests/pipelines/test_person_image_cartoon.py index bdbf8b61..90aaa500 100644 --- a/tests/pipelines/test_person_image_cartoon.py +++ b/tests/pipelines/test_person_image_cartoon.py @@ -1,5 +1,4 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import os import os.path as osp import unittest diff --git a/tests/pipelines/test_video_inpainting.py b/tests/pipelines/test_video_inpainting.py new file mode 100644 index 00000000..8364b1b3 --- /dev/null +++ b/tests/pipelines/test_video_inpainting.py @@ -0,0 +1,39 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
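+# Note: inpainting inference moves the frame and mask tensors to CUDA, so these tests assume a GPU is available.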
+import unittest
+
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class VideoInpaintingTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.model = 'damo/cv_video-inpainting'
+        self.mask_dir = 'data/test/videos/mask_dir'
+        self.video_in = 'data/test/videos/video_inpainting_test.mp4'
+        self.video_out = 'out.mp4'
+        self.input = {
+            'video_input_path': self.video_in,
+            'video_output_path': self.video_out,
+            'mask_path': self.mask_dir
+        }
+
+    def pipeline_inference(self, pipeline: Pipeline, input: dict):
+        result = pipeline(input)
+        print(result)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_modelhub(self):
+        video_inpainting = pipeline(Tasks.video_inpainting, model=self.model)
+        self.pipeline_inference(video_inpainting, self.input)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_modelhub_default_model(self):
+        video_inpainting = pipeline(Tasks.video_inpainting)
+        self.pipeline_inference(video_inpainting, self.input)
+
+
+if __name__ == '__main__':
+    unittest.main()
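A minimal usage sketch for the new pipeline, following tests/pipelines/test_video_inpainting.py. It only uses names introduced in this patch: the video-inpainting task, the default model id 'damo/cv_video-inpainting', the three input keys, and the 'Done' / 'decode_error' status returned under OutputKeys.OUTPUT. Mask files are expected to be named mask_<begin>_<end>.png for the frame range they cover, and inference as written runs on CUDA.

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Build the pipeline; 'damo/cv_video-inpainting' is the default model
    # registered for Tasks.video_inpainting in builder.py.
    video_inpainting = pipeline(
        Tasks.video_inpainting, model='damo/cv_video-inpainting')

    # The pipeline takes the input video, the path to write the result to,
    # and a directory of masks named mask_<begin>_<end>.png per frame range.
    result = video_inpainting({
        'video_input_path': 'data/test/videos/video_inpainting_test.mp4',
        'video_output_path': 'out.mp4',
        'mask_path': 'data/test/videos/mask_dir'
    })

    # 'Done' on success, 'decode_error' if the input video cannot be read.
    print(result[OutputKeys.OUTPUT])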