Code review: video inpainting (video editing)
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10026166
master
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:b158f6029d9763d7f84042f7c5835f398c688fdbb6b3f4fe6431101d4118c66c | |||
| size 2766 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:0dcf46b93077e2229ab69cd6ddb80e2689546c575ee538bb2033fee1124ef3e3 | |||
| size 2761 | |||
| @@ -0,0 +1,3 @@ | |||
| version https://git-lfs.github.com/spec/v1 | |||
| oid sha256:9c9870df5a86acaaec67063183dace795479cd0f05296f13058995f475149c56 | |||
| size 2957783 | |||
| @@ -38,6 +38,7 @@ class Models(object): | |||
| mogface = 'mogface' | |||
| mtcnn = 'mtcnn' | |||
| ulfd = 'ulfd' | |||
| video_inpainting = 'video-inpainting' | |||
| # EasyCV models | |||
| yolox = 'YOLOX' | |||
| @@ -169,6 +170,7 @@ class Pipelines(object): | |||
| text_driven_segmentation = 'text-driven-segmentation' | |||
| movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation' | |||
| shop_segmentation = 'shop-segmentation' | |||
| video_inpainting = 'video-inpainting' | |||
| # nlp tasks | |||
| sentence_similarity = 'sentence-similarity' | |||
| @@ -0,0 +1,20 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .inpainting_model import VideoInpainting | |||
| else: | |||
| _import_structure = {'inpainting_model': ['VideoInpainting']} | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
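Reviewer note (illustrative, not part of the diff): with the LazyImportModule registration above, importing the package itself does not load inpainting_model; the class is resolved only on first attribute access, e.g.:

    import modelscope.models.cv.video_inpainting as vi  # inpainting_model not imported yet
    model_cls = vi.VideoInpainting                       # first access triggers the real import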
| @@ -0,0 +1,298 @@ | |||
| """ VideoInpaintingProcess | |||
| Base modules are adapted from https://github.com/researchmm/STTN, | |||
| originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, | |||
| """ | |||
| import os | |||
| import time | |||
| import cv2 | |||
| import numpy as np | |||
| import torch | |||
| from PIL import Image | |||
| from torchvision import transforms | |||
| torch.backends.cudnn.enabled = False | |||
| w, h = 192, 96 | |||
| ref_length = 300 | |||
| neighbor_stride = 20 | |||
| default_fps = 24 | |||
| MAX_frame = 300 | |||
| def video_process(video_input_path): | |||
| video_input = cv2.VideoCapture(video_input_path) | |||
| success, frame = video_input.read() | |||
| if success is False: | |||
| decode_error = 'decode_error' | |||
| w, h, fps = 0, 0, 0 | |||
| else: | |||
| decode_error = None | |||
| h, w = frame.shape[0:2] | |||
| fps = video_input.get(cv2.CAP_PROP_FPS) | |||
| video_input.release() | |||
| return decode_error, fps, w, h | |||
| class Stack(object): | |||
| def __init__(self, roll=False): | |||
| self.roll = roll | |||
| def __call__(self, img_group): | |||
| mode = img_group[0].mode | |||
| if mode == '1': | |||
| img_group = [img.convert('L') for img in img_group] | |||
| mode = 'L' | |||
| if mode == 'L': | |||
| return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2) | |||
| elif mode == 'RGB': | |||
| if self.roll: | |||
| return np.stack([np.array(x)[:, :, ::-1] for x in img_group], | |||
| axis=2) | |||
| else: | |||
| return np.stack(img_group, axis=2) | |||
| else: | |||
| raise NotImplementedError(f'Image mode {mode}') | |||
| class ToTorchFormatTensor(object): | |||
| """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] | |||
| to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ | |||
| def __init__(self, div=True): | |||
| self.div = div | |||
| def __call__(self, pic): | |||
| if isinstance(pic, np.ndarray): | |||
| img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous() | |||
| else: | |||
| img = torch.ByteTensor( | |||
| torch.ByteStorage.from_buffer(pic.tobytes())) | |||
| img = img.view(pic.size[1], pic.size[0], len(pic.mode)) | |||
| img = img.transpose(0, 1).transpose(0, 2).contiguous() | |||
| img = img.float().div(255) if self.div else img.float() | |||
| return img | |||
| _to_tensors = transforms.Compose([Stack(), ToTorchFormatTensor()]) | |||
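Reviewer note (illustrative only, not part of the diff): Stack turns a list of T RGB PIL frames into an (H, W, T, 3) uint8 array, and ToTorchFormatTensor permutes that to a (T, 3, H, W) float tensor scaled to [0, 1], which is the layout the rest of this file relies on:

    from PIL import Image
    import numpy as np

    frames = [Image.fromarray(np.zeros((96, 192, 3), dtype=np.uint8)) for _ in range(5)]
    print(_to_tensors(frames).shape)  # torch.Size([5, 3, 96, 192])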
| def get_crop_mask_v1(mask): | |||
| orig_h, orig_w, _ = mask.shape | |||
| if (mask == 255).all(): | |||
| return mask, (0, int(orig_h), 0, int(orig_w)), \ | |||
| [0, int(orig_h), 0, int(orig_w)], \ | |||
| [0, int(orig_h), 0, int(orig_w)] | |||
| hs = np.min(np.where(mask == 0)[0]) | |||
| he = np.max(np.where(mask == 0)[0]) | |||
| ws = np.min(np.where(mask == 0)[1]) | |||
| we = np.max(np.where(mask == 0)[1]) | |||
| crop_box = [ws, hs, we, he] | |||
| mask_h = round(int(orig_h / 2) / 4) * 4 | |||
| mask_w = round(int(orig_w / 2) / 4) * 4 | |||
| if (hs < mask_h) and (he < mask_h) and (ws < mask_w) and (we < mask_w): | |||
| crop_mask = mask[:mask_h, :mask_w, :] | |||
| res_pix = (0, mask_h, 0, mask_w) | |||
| elif (hs < mask_h) and (he < mask_h) and (ws > mask_w) and (we > mask_w): | |||
| crop_mask = mask[:mask_h, orig_w - mask_w:orig_w, :] | |||
| res_pix = (0, mask_h, orig_w - mask_w, int(orig_w)) | |||
| elif (hs > mask_h) and (he > mask_h) and (ws < mask_w) and (we < mask_w): | |||
| crop_mask = mask[orig_h - mask_h:orig_h, :mask_w, :] | |||
| res_pix = (orig_h - mask_h, int(orig_h), 0, mask_w) | |||
| elif (hs > mask_h) and (he > mask_h) and (ws > mask_w) and (we > mask_w): | |||
| crop_mask = mask[orig_h - mask_h:orig_h, orig_w - mask_w:orig_w, :] | |||
| res_pix = (orig_h - mask_h, int(orig_h), orig_w - mask_w, int(orig_w)) | |||
| elif (hs < mask_h) and (he < mask_h) and (ws < mask_w) and (we > mask_w): | |||
| crop_mask = mask[:mask_h, :, :] | |||
| res_pix = (0, mask_h, 0, int(orig_w)) | |||
| elif (hs < mask_h) and (he > mask_h) and (ws < mask_w) and (we < mask_w): | |||
| crop_mask = mask[:, :mask_w, :] | |||
| res_pix = (0, int(orig_h), 0, mask_w) | |||
| elif (hs > mask_h) and (he > mask_h) and (ws < mask_w) and (we > mask_w): | |||
| crop_mask = mask[orig_h - mask_h:orig_h, :, :] | |||
| res_pix = (orig_h - mask_h, int(orig_h), 0, int(orig_w)) | |||
| elif (hs < mask_h) and (he > mask_h) and (ws > mask_w) and (we > mask_w): | |||
| crop_mask = mask[:, orig_w - mask_w:orig_w, :] | |||
| res_pix = (0, int(orig_h), orig_w - mask_w, int(orig_w)) | |||
| else: | |||
| crop_mask = mask | |||
| res_pix = (0, int(orig_h), 0, int(orig_w)) | |||
| a = ws - res_pix[2] | |||
| b = hs - res_pix[0] | |||
| c = we - res_pix[2] | |||
| d = he - res_pix[0] | |||
| return crop_mask, res_pix, crop_box, [a, b, c, d] | |||
| def get_ref_index(neighbor_ids, length): | |||
| ref_index = [] | |||
| for i in range(0, length, ref_length): | |||
| if i not in neighbor_ids: | |||
| ref_index.append(i) | |||
| return ref_index | |||
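Reviewer note: because ref_length (300) is not smaller than MAX_frame (300), range(0, length, ref_length) only ever yields index 0, so each chunk gets at most one long-range reference frame, and none when frame 0 is already inside the neighbor window:

    print(get_ref_index(neighbor_ids=[0, 1, 2], length=300))     # []
    print(get_ref_index(neighbor_ids=[40, 41, 42], length=300))  # [0]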
| def read_mask_oneImage(mpath): | |||
| masks = [] | |||
| print('mask_path: {}'.format(mpath)) | |||
| start = int(mpath.split('/')[-1].split('mask_')[1].split('_')[0]) | |||
| end = int( | |||
| mpath.split('/')[-1].split('mask_')[1].split('_')[1].split('.')[0]) | |||
| m = Image.open(mpath) | |||
| m = np.array(m.convert('L')) | |||
| m = np.array(m > 0).astype(np.uint8) | |||
| m = 1 - m | |||
| for i in range(start - 1, end + 1): | |||
| masks.append(Image.fromarray(m * 255)) | |||
| return masks | |||
| def check_size(h, w): | |||
| is_resize = False | |||
| if h != 240: | |||
| h = 240 | |||
| is_resize = True | |||
| if w != 432: | |||
| w = 432 | |||
| is_resize = True | |||
| return is_resize | |||
| def get_mask_list(mask_path): | |||
| mask_names = os.listdir(mask_path) | |||
| mask_names.sort() | |||
| abs_mask_path = [] | |||
| mask_list = [] | |||
| begin_list = [] | |||
| end_list = [] | |||
| for mask_name in mask_names: | |||
| mask_name_tmp = mask_name.split('mask_')[1] | |||
| begin_list.append(int(mask_name_tmp.split('_')[0])) | |||
| end_list.append(int(mask_name_tmp.split('_')[1].split('.')[0])) | |||
| abs_mask_path.append(os.path.join(mask_path, mask_name)) | |||
| mask = cv2.imread(os.path.join(mask_path, mask_name)) | |||
| mask_list.append(mask) | |||
| return mask_list, begin_list, end_list, abs_mask_path | |||
| def inpainting_by_model_balance(model, video_inputPath, mask_path, | |||
| video_savePath, fps, w_ori, h_ori): | |||
| video_ori = cv2.VideoCapture(video_inputPath) | |||
| fourcc = cv2.VideoWriter_fourcc(*'mp4v') | |||
| video_save = cv2.VideoWriter(video_savePath, fourcc, fps, (w_ori, h_ori)) | |||
| mask_list, begin_list, end_list, abs_mask_path = get_mask_list(mask_path) | |||
| img_npy = [] | |||
| for index, mask in enumerate(mask_list): | |||
| masks = read_mask_oneImage(abs_mask_path[index]) | |||
| mask, res_pix, crop_for_oriimg, crop_for_inpimg = get_crop_mask_v1( | |||
| mask) | |||
| mask_h, mask_w = mask.shape[0:2] | |||
| is_resize = check_size(mask.shape[0], mask.shape[1]) | |||
| begin = begin_list[index] | |||
| end = end_list[index] | |||
| print('begin: {}'.format(begin)) | |||
| print('end: {}'.format(end)) | |||
| for i in range(begin, end + 1, MAX_frame): | |||
| begin_time = time.time() | |||
| if i + MAX_frame <= end: | |||
| video_length = MAX_frame | |||
| else: | |||
| video_length = end - i + 1 | |||
| for frame_count in range(video_length): | |||
| _, frame = video_ori.read() | |||
| img_npy.append(frame) | |||
| frames_temp = [] | |||
| for f in img_npy: | |||
| f = Image.fromarray(f) | |||
| i_temp = f.crop( | |||
| (res_pix[2], res_pix[0], res_pix[3], res_pix[1])) | |||
| a = i_temp.resize((w, h), Image.NEAREST) | |||
| frames_temp.append(a) | |||
| feats_temp = _to_tensors(frames_temp).unsqueeze(0) * 2 - 1 | |||
| frames_temp = [np.array(f).astype(np.uint8) for f in frames_temp] | |||
| masks_temp = [] | |||
| for m in masks[i - begin:i + video_length - begin]: | |||
| m_temp = m.crop( | |||
| (res_pix[2], res_pix[0], res_pix[3], res_pix[1])) | |||
| b = m_temp.resize((w, h), Image.NEAREST) | |||
| masks_temp.append(b) | |||
| binary_masks_temp = [ | |||
| np.expand_dims((np.array(m) != 0).astype(np.uint8), 2) | |||
| for m in masks_temp | |||
| ] | |||
| masks_temp = _to_tensors(masks_temp).unsqueeze(0) | |||
| feats_temp, masks_temp = feats_temp.cuda(), masks_temp.cuda() | |||
| comp_frames = [None] * video_length | |||
| model.eval() | |||
| with torch.no_grad(): | |||
| feats_out = feats_temp * (1 - masks_temp).float() | |||
| feats_out = feats_out.view(video_length, 3, h, w) | |||
| feats_out = model.model.encoder(feats_out) | |||
| _, c, feat_h, feat_w = feats_out.size() | |||
| feats_out = feats_out.view(1, video_length, c, feat_h, feat_w) | |||
| for f in range(0, video_length, neighbor_stride): | |||
| neighbor_ids = [ | |||
| i for i in range( | |||
| max(0, f - neighbor_stride), | |||
| min(video_length, f + neighbor_stride + 1)) | |||
| ] | |||
| ref_ids = get_ref_index(neighbor_ids, video_length) | |||
| with torch.no_grad(): | |||
| pred_feat = model.model.infer( | |||
| feats_out[0, neighbor_ids + ref_ids, :, :, :], | |||
| masks_temp[0, neighbor_ids + ref_ids, :, :, :]) | |||
| pred_img = torch.tanh( | |||
| model.model.decoder( | |||
| pred_feat[:len(neighbor_ids), :, :, :])).detach() | |||
| pred_img = (pred_img + 1) / 2 | |||
| pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255 | |||
| for j in range(len(neighbor_ids)): | |||
| idx = neighbor_ids[j] | |||
| img = np.array(pred_img[j]).astype( | |||
| np.uint8) * binary_masks_temp[idx] + frames_temp[ | |||
| idx] * (1 - binary_masks_temp[idx]) | |||
| if comp_frames[idx] is None: | |||
| comp_frames[idx] = img | |||
| else: | |||
| comp_frames[idx] = comp_frames[idx].astype( | |||
| np.float32) * 0.5 + img.astype( | |||
| np.float32) * 0.5 | |||
| print('inpainting time:', time.time() - begin_time) | |||
| for f in range(video_length): | |||
| comp = np.array(comp_frames[f]).astype( | |||
| np.uint8) * binary_masks_temp[f] + frames_temp[f] * ( | |||
| 1 - binary_masks_temp[f]) | |||
| if is_resize: | |||
| comp = cv2.resize(comp, (mask_w, mask_h)) | |||
| complete_frame = img_npy[f] | |||
| a1, b1, c1, d1 = crop_for_oriimg | |||
| a2, b2, c2, d2 = crop_for_inpimg | |||
| complete_frame[b1:d1, a1:c1] = comp[b2:d2, a2:c2] | |||
| video_save.write(complete_frame) | |||
| img_npy = [] | |||
| video_ori.release() | |||
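Reviewer note on the sliding window above: windows advance by neighbor_stride (20) but extend up to 20 frames on either side of f, so most frames are predicted by more than one window and overlapping predictions are blended with equal 0.5/0.5 weight. For a 60-frame chunk the neighbor windows would be:

    video_length, neighbor_stride = 60, 20
    for f in range(0, video_length, neighbor_stride):
        ids = list(range(max(0, f - neighbor_stride),
                         min(video_length, f + neighbor_stride + 1)))
        print(ids[0], '..', ids[-1])  # 0..20, 0..40, 20..59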
| @@ -0,0 +1,373 @@ | |||
| """ VideoInpaintingNetwork | |||
| Base modules are adapted from https://github.com/researchmm/STTN, | |||
| originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, | |||
| """ | |||
| import math | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| # Needed by the spectral_norm() helper at the bottom of this file; the | |||
| # import appears to be missing from the diff as shown. | |||
| from torch.nn.utils import spectral_norm as _spectral_norm | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| class BaseNetwork(nn.Module): | |||
| def __init__(self): | |||
| super(BaseNetwork, self).__init__() | |||
| def print_network(self): | |||
| if isinstance(self, list): | |||
| self = self[0] | |||
| num_params = 0 | |||
| for param in self.parameters(): | |||
| num_params += param.numel() | |||
| print( | |||
| 'Network [%s] was created. Total number of parameters: %.1f million. ' | |||
| 'To see the architecture, do print(network).' % | |||
| (type(self).__name__, num_params / 1000000)) | |||
| def init_weights(self, init_type='normal', gain=0.02): | |||
| ''' | |||
| initialize network's weights | |||
| init_type: normal | xavier | kaiming | orthogonal | |||
| https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 | |||
| ''' | |||
| def init_func(m): | |||
| classname = m.__class__.__name__ | |||
| if classname.find('InstanceNorm2d') != -1: | |||
| if hasattr(m, 'weight') and m.weight is not None: | |||
| nn.init.constant_(m.weight.data, 1.0) | |||
| if hasattr(m, 'bias') and m.bias is not None: | |||
| nn.init.constant_(m.bias.data, 0.0) | |||
| elif hasattr(m, 'weight') and (classname.find('Conv') != -1 | |||
| or classname.find('Linear') != -1): | |||
| if init_type == 'normal': | |||
| nn.init.normal_(m.weight.data, 0.0, gain) | |||
| elif init_type == 'xavier': | |||
| nn.init.xavier_normal_(m.weight.data, gain=gain) | |||
| elif init_type == 'xavier_uniform': | |||
| nn.init.xavier_uniform_(m.weight.data, gain=1.0) | |||
| elif init_type == 'kaiming': | |||
| nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') | |||
| elif init_type == 'orthogonal': | |||
| nn.init.orthogonal_(m.weight.data, gain=gain) | |||
| elif init_type == 'none': | |||
| m.reset_parameters() | |||
| else: | |||
| raise NotImplementedError( | |||
| 'initialization method [%s] is not implemented' | |||
| % init_type) | |||
| if hasattr(m, 'bias') and m.bias is not None: | |||
| nn.init.constant_(m.bias.data, 0.0) | |||
| self.apply(init_func) | |||
| for m in self.children(): | |||
| if hasattr(m, 'init_weights'): | |||
| m.init_weights(init_type, gain) | |||
| @MODELS.register_module( | |||
| Tasks.video_inpainting, module_name=Models.video_inpainting) | |||
| class VideoInpainting(TorchModel): | |||
| def __init__(self, model_dir, device_id=0, *args, **kwargs): | |||
| super().__init__( | |||
| model_dir=model_dir, device_id=device_id, *args, **kwargs) | |||
| self.model = InpaintGenerator() | |||
| pretrained_params = torch.load('{}/{}'.format( | |||
| model_dir, ModelFile.TORCH_MODEL_BIN_FILE)) | |||
| self.model.load_state_dict(pretrained_params['netG']) | |||
| self.model.eval() | |||
| self.device_id = device_id | |||
| if self.device_id >= 0 and torch.cuda.is_available(): | |||
| self.model.to('cuda:{}'.format(self.device_id)) | |||
| logger.info('Use GPU: {}'.format(self.device_id)) | |||
| else: | |||
| self.device_id = -1 | |||
| logger.info('Use CPU for inference') | |||
| class InpaintGenerator(BaseNetwork): | |||
| def __init__(self, init_weights=True): | |||
| super(InpaintGenerator, self).__init__() | |||
| channel = 256 | |||
| stack_num = 6 | |||
| patchsize = [(48, 24), (16, 8), (8, 4), (4, 2)] | |||
| blocks = [] | |||
| for _ in range(stack_num): | |||
| blocks.append(TransformerBlock(patchsize, hidden=channel)) | |||
| self.transformer = nn.Sequential(*blocks) | |||
| self.encoder = nn.Sequential( | |||
| nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), | |||
| nn.LeakyReLU(0.2, inplace=True), | |||
| nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), | |||
| nn.LeakyReLU(0.2, inplace=True), | |||
| nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), | |||
| nn.LeakyReLU(0.2, inplace=True), | |||
| nn.Conv2d(128, channel, kernel_size=3, stride=1, padding=1), | |||
| nn.LeakyReLU(0.2, inplace=True), | |||
| ) | |||
| self.decoder = nn.Sequential( | |||
| deconv(channel, 128, kernel_size=3, padding=1), | |||
| nn.LeakyReLU(0.2, inplace=True), | |||
| nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), | |||
| nn.LeakyReLU(0.2, inplace=True), | |||
| deconv(64, 64, kernel_size=3, padding=1), | |||
| nn.LeakyReLU(0.2, inplace=True), | |||
| nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)) | |||
| if init_weights: | |||
| self.init_weights() | |||
| def forward(self, masked_frames, masks): | |||
| b, t, c, h, w = masked_frames.size() | |||
| masks = masks.view(b * t, 1, h, w) | |||
| enc_feat = self.encoder(masked_frames.view(b * t, c, h, w)) | |||
| _, c, h, w = enc_feat.size() | |||
| masks = F.interpolate(masks, scale_factor=1.0 / 4) | |||
| enc_feat = self.transformer({ | |||
| 'x': enc_feat, | |||
| 'm': masks, | |||
| 'b': b, | |||
| 'c': c | |||
| })['x'] | |||
| output = self.decoder(enc_feat) | |||
| output = torch.tanh(output) | |||
| return output | |||
| def infer(self, feat, masks): | |||
| t, c, h, w = masks.size() | |||
| masks = masks.view(t, c, h, w) | |||
| masks = F.interpolate(masks, scale_factor=1.0 / 4) | |||
| t, c, _, _ = feat.size() | |||
| enc_feat = self.transformer({ | |||
| 'x': feat, | |||
| 'm': masks, | |||
| 'b': 1, | |||
| 'c': c | |||
| })['x'] | |||
| return enc_feat | |||
| class deconv(nn.Module): | |||
| def __init__(self, | |||
| input_channel, | |||
| output_channel, | |||
| kernel_size=3, | |||
| padding=0): | |||
| super().__init__() | |||
| self.conv = nn.Conv2d( | |||
| input_channel, | |||
| output_channel, | |||
| kernel_size=kernel_size, | |||
| stride=1, | |||
| padding=padding) | |||
| def forward(self, x): | |||
| x = F.interpolate( | |||
| x, scale_factor=2, mode='bilinear', align_corners=True) | |||
| x = self.conv(x) | |||
| return x | |||
| class Attention(nn.Module): | |||
| """ | |||
| Compute 'Scaled Dot Product Attention'. | |||
| """ | |||
| def forward(self, query, key, value, m): | |||
| scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt( | |||
| query.size(-1)) | |||
| # masked_fill is not in-place; assign the result so the mask is applied | |||
| scores = scores.masked_fill(m, -1e9) | |||
| p_attn = F.softmax(scores, dim=-1) | |||
| p_val = torch.matmul(p_attn, value) | |||
| return p_val, p_attn | |||
| class MultiHeadedAttention(nn.Module): | |||
| """ | |||
| Take in model size and number of heads. | |||
| """ | |||
| def __init__(self, patchsize, d_model): | |||
| super().__init__() | |||
| self.patchsize = patchsize | |||
| self.query_embedding = nn.Conv2d( | |||
| d_model, d_model, kernel_size=1, padding=0) | |||
| self.value_embedding = nn.Conv2d( | |||
| d_model, d_model, kernel_size=1, padding=0) | |||
| self.key_embedding = nn.Conv2d( | |||
| d_model, d_model, kernel_size=1, padding=0) | |||
| self.output_linear = nn.Sequential( | |||
| nn.Conv2d(d_model, d_model, kernel_size=3, padding=1), | |||
| nn.LeakyReLU(0.2, inplace=True)) | |||
| self.attention = Attention() | |||
| def forward(self, x, m, b, c): | |||
| bt, _, h, w = x.size() | |||
| t = bt // b | |||
| d_k = c // len(self.patchsize) | |||
| output = [] | |||
| _query = self.query_embedding(x) | |||
| _key = self.key_embedding(x) | |||
| _value = self.value_embedding(x) | |||
| for (width, height), query, key, value in zip( | |||
| self.patchsize, | |||
| torch.chunk(_query, len(self.patchsize), dim=1), | |||
| torch.chunk(_key, len(self.patchsize), dim=1), | |||
| torch.chunk(_value, len(self.patchsize), dim=1)): | |||
| out_w, out_h = w // width, h // height | |||
| mm = m.view(b, t, 1, out_h, height, out_w, width) | |||
| mm = mm.permute(0, 1, 3, 5, 2, 4, | |||
| 6).contiguous().view(b, t * out_h * out_w, | |||
| height * width) | |||
| mm = (mm.mean(-1) > 0.5).unsqueeze(1).repeat( | |||
| 1, t * out_h * out_w, 1) | |||
| query = query.view(b, t, d_k, out_h, height, out_w, width) | |||
| query = query.permute(0, 1, 3, 5, 2, 4, | |||
| 6).contiguous().view(b, t * out_h * out_w, | |||
| d_k * height * width) | |||
| key = key.view(b, t, d_k, out_h, height, out_w, width) | |||
| key = key.permute(0, 1, 3, 5, 2, 4, | |||
| 6).contiguous().view(b, t * out_h * out_w, | |||
| d_k * height * width) | |||
| value = value.view(b, t, d_k, out_h, height, out_w, width) | |||
| value = value.permute(0, 1, 3, 5, 2, 4, | |||
| 6).contiguous().view(b, t * out_h * out_w, | |||
| d_k * height * width) | |||
| y, _ = self.attention(query, key, value, mm) | |||
| y = y.view(b, t, out_h, out_w, d_k, height, width) | |||
| y = y.permute(0, 1, 4, 2, 5, 3, 6).contiguous().view(bt, d_k, h, w) | |||
| output.append(y) | |||
| output = torch.cat(output, 1) | |||
| x = self.output_linear(output) | |||
| return x | |||
| class FeedForward(nn.Module): | |||
| def __init__(self, d_model): | |||
| super(FeedForward, self).__init__() | |||
| self.conv = nn.Sequential( | |||
| nn.Conv2d(d_model, d_model, kernel_size=3, padding=2, dilation=2), | |||
| nn.LeakyReLU(0.2, inplace=True), | |||
| nn.Conv2d(d_model, d_model, kernel_size=3, padding=1), | |||
| nn.LeakyReLU(0.2, inplace=True)) | |||
| def forward(self, x): | |||
| x = self.conv(x) | |||
| return x | |||
| class TransformerBlock(nn.Module): | |||
| """ | |||
| Transformer = MultiHead_Attention + Feed_Forward with sublayer connection | |||
| """ | |||
| def __init__(self, patchsize, hidden=128): # hidden=128 | |||
| super().__init__() | |||
| self.attention = MultiHeadedAttention(patchsize, d_model=hidden) | |||
| self.feed_forward = FeedForward(hidden) | |||
| def forward(self, x): | |||
| x, m, b, c = x['x'], x['m'], x['b'], x['c'] | |||
| x = x + self.attention(x, m, b, c) | |||
| x = x + self.feed_forward(x) | |||
| return {'x': x, 'm': m, 'b': b, 'c': c} | |||
| class Discriminator(BaseNetwork): | |||
| def __init__(self, | |||
| in_channels=3, | |||
| use_sigmoid=False, | |||
| use_spectral_norm=True, | |||
| init_weights=True): | |||
| super(Discriminator, self).__init__() | |||
| self.use_sigmoid = use_sigmoid | |||
| nf = 64 | |||
| self.conv = nn.Sequential( | |||
| spectral_norm( | |||
| nn.Conv3d( | |||
| in_channels=in_channels, | |||
| out_channels=nf * 1, | |||
| kernel_size=(3, 5, 5), | |||
| stride=(1, 2, 2), | |||
| padding=1, | |||
| bias=not use_spectral_norm), use_spectral_norm), | |||
| nn.LeakyReLU(0.2, inplace=True), | |||
| spectral_norm( | |||
| nn.Conv3d( | |||
| nf * 1, | |||
| nf * 2, | |||
| kernel_size=(3, 5, 5), | |||
| stride=(1, 2, 2), | |||
| padding=(1, 2, 2), | |||
| bias=not use_spectral_norm), use_spectral_norm), | |||
| nn.LeakyReLU(0.2, inplace=True), | |||
| spectral_norm( | |||
| nn.Conv3d( | |||
| nf * 2, | |||
| nf * 4, | |||
| kernel_size=(3, 5, 5), | |||
| stride=(1, 2, 2), | |||
| padding=(1, 2, 2), | |||
| bias=not use_spectral_norm), use_spectral_norm), | |||
| nn.LeakyReLU(0.2, inplace=True), | |||
| spectral_norm( | |||
| nn.Conv3d( | |||
| nf * 4, | |||
| nf * 4, | |||
| kernel_size=(3, 5, 5), | |||
| stride=(1, 2, 2), | |||
| padding=(1, 2, 2), | |||
| bias=not use_spectral_norm), use_spectral_norm), | |||
| nn.LeakyReLU(0.2, inplace=True), | |||
| spectral_norm( | |||
| nn.Conv3d( | |||
| nf * 4, | |||
| nf * 4, | |||
| kernel_size=(3, 5, 5), | |||
| stride=(1, 2, 2), | |||
| padding=(1, 2, 2), | |||
| bias=not use_spectral_norm), use_spectral_norm), | |||
| nn.LeakyReLU(0.2, inplace=True), | |||
| nn.Conv3d( | |||
| nf * 4, | |||
| nf * 4, | |||
| kernel_size=(3, 5, 5), | |||
| stride=(1, 2, 2), | |||
| padding=(1, 2, 2))) | |||
| if init_weights: | |||
| self.init_weights() | |||
| def forward(self, xs): | |||
| xs_t = torch.transpose(xs, 0, 1) | |||
| xs_t = xs_t.unsqueeze(0) | |||
| feat = self.conv(xs_t) | |||
| if self.use_sigmoid: | |||
| feat = torch.sigmoid(feat) | |||
| out = torch.transpose(feat, 1, 2) | |||
| return out | |||
| def spectral_norm(module, mode=True): | |||
| if mode: | |||
| return _spectral_norm(module) | |||
| return module | |||
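Reviewer note: a minimal CPU sanity-check sketch of the shapes the processing code relies on (not part of the diff; the tensor sizes assume the fixed 192x96 processing resolution defined in the processing file):

    import torch
    from modelscope.models.cv.video_inpainting.inpainting_model import InpaintGenerator

    net = InpaintGenerator().eval()
    t, h, w = 4, 96, 192                        # 4 frames at the fixed processing size
    frames = torch.rand(t, 3, h, w) * 2 - 1     # [-1, 1], as produced by _to_tensors * 2 - 1
    masks = torch.rand(t, 1, h, w).round()      # binary hole masks
    with torch.no_grad():
        feat = net.encoder(frames)              # (4, 256, 24, 48): 1/4 spatial, 256 channels
        filled = net.infer(feat, masks)         # transformer only, same shape as feat
        out = torch.tanh(net.decoder(filled))   # (4, 3, 96, 192) in [-1, 1]
    print(feat.shape, out.shape)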
| @@ -610,4 +610,9 @@ TASK_OUTPUTS = { | |||
| # "img_embedding": np.array with shape [1, D], | |||
| # } | |||
| Tasks.image_reid_person: [OutputKeys.IMG_EMBEDDING], | |||
| # { | |||
| # 'output': 'Done' or 'decode_error' | |||
| # } | |||
| Tasks.video_inpainting: [OutputKeys.OUTPUT] | |||
| } | |||
| @@ -168,6 +168,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||
| 'damo/cv_resnet50-bert_video-scene-segmentation_movienet'), | |||
| Tasks.shop_segmentation: (Pipelines.shop_segmentation, | |||
| 'damo/cv_vitb16_segmentation_shop-seg'), | |||
| Tasks.video_inpainting: (Pipelines.video_inpainting, | |||
| 'damo/cv_video-inpainting'), | |||
| } | |||
| @@ -0,0 +1,47 @@ | |||
| from typing import Any, Dict | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.cv.video_inpainting import inpainting | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.video_inpainting, module_name=Pipelines.video_inpainting) | |||
| class VideoInpaintingPipeline(Pipeline): | |||
| def __init__(self, model: str, **kwargs): | |||
| """ | |||
| Use `model` to create a video inpainting pipeline for prediction. | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| super().__init__(model=model, **kwargs) | |||
| logger.info('load model done') | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| return input | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| decode_error, fps, w, h = inpainting.video_process( | |||
| input['video_input_path']) | |||
| if decode_error is not None: | |||
| return {OutputKeys.OUTPUT: 'decode_error'} | |||
| inpainting.inpainting_by_model_balance(self.model, | |||
| input['video_input_path'], | |||
| input['mask_path'], | |||
| input['video_output_path'], fps, | |||
| w, h) | |||
| return {OutputKeys.OUTPUT: 'Done'} | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| return inputs | |||
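Reviewer note: for quick reference, the intended call pattern, mirroring the unit test added below (the model id and paths are taken from that test):

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    inpainter = pipeline(Tasks.video_inpainting, model='damo/cv_video-inpainting')
    result = inpainter({
        'video_input_path': 'data/test/videos/video_inpainting_test.mp4',
        'mask_path': 'data/test/videos/mask_dir',
        'video_output_path': 'out.mp4',
    })
    print(result[OutputKeys.OUTPUT])  # 'Done', or 'decode_error' if the input cannot be decoded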
| @@ -70,6 +70,9 @@ class CVTasks(object): | |||
| crowd_counting = 'crowd-counting' | |||
| movie_scene_segmentation = 'movie-scene-segmentation' | |||
| # video editing | |||
| video_inpainting = 'video-inpainting' | |||
| # reid and tracking | |||
| video_single_object_tracking = 'video-single-object-tracking' | |||
| video_summarization = 'video-summarization' | |||
| @@ -1,5 +1,4 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| import unittest | |||
| @@ -0,0 +1,39 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import unittest | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class VideoInpaintingTest(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| self.model = 'damo/cv_video-inpainting' | |||
| self.mask_dir = 'data/test/videos/mask_dir' | |||
| self.video_in = 'data/test/videos/video_inpainting_test.mp4' | |||
| self.video_out = 'out.mp4' | |||
| self.input = { | |||
| 'video_input_path': self.video_in, | |||
| 'video_output_path': self.video_out, | |||
| 'mask_path': self.mask_dir | |||
| } | |||
| def pipeline_inference(self, pipeline: Pipeline, input: str): | |||
| result = pipeline(input) | |||
| print(result) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_modelhub(self): | |||
| video_inpainting = pipeline(Tasks.video_inpainting, model=self.model) | |||
| self.pipeline_inference(video_inpainting, self.input) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_modelhub_default_model(self): | |||
| video_inpainting = pipeline(Tasks.video_inpainting) | |||
| self.pipeline_inference(video_inpainting, self.input) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||