CR for video editing (video inpainting)
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10026166
master
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b158f6029d9763d7f84042f7c5835f398c688fdbb6b3f4fe6431101d4118c66c
size 2766

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0dcf46b93077e2229ab69cd6ddb80e2689546c575ee538bb2033fee1124ef3e3
size 2761

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9c9870df5a86acaaec67063183dace795479cd0f05296f13058995f475149c56
size 2957783

@@ -38,6 +38,7 @@ class Models(object):
    mogface = 'mogface'
    mtcnn = 'mtcnn'
    ulfd = 'ulfd'
    video_inpainting = 'video-inpainting'

    # EasyCV models
    yolox = 'YOLOX'

@@ -169,6 +170,7 @@ class Pipelines(object):
    text_driven_segmentation = 'text-driven-segmentation'
    movie_scene_segmentation = 'resnet50-bert-movie-scene-segmentation'
    shop_segmentation = 'shop-segmentation'
    video_inpainting = 'video-inpainting'

    # nlp tasks
    sentence_similarity = 'sentence-similarity'

@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .inpainting_model import VideoInpainting
else:
    _import_structure = {'inpainting_model': ['VideoInpainting']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

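Reviewer note: this follows the usual ModelScope lazy-import pattern, so the class is only resolved on first attribute access. A minimal sketch of the intended usage (illustrative only; the package path matches the import used by the pipeline below, the local directory is hypothetical):

# Illustrative only: the import is resolved lazily on first access.
from modelscope.models.cv.video_inpainting import VideoInpainting
model = VideoInpainting(model_dir='/path/to/downloaded/model_dir')  # hypothetical local snapshot
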
@@ -0,0 +1,298 @@
""" VideoInpaintingProcess
Base modules are adapted from https://github.com/researchmm/STTN,
originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab.
"""
import os
import time

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

torch.backends.cudnn.enabled = False

w, h = 192, 96
ref_length = 300
neighbor_stride = 20
default_fps = 24
MAX_frame = 300

def video_process(video_input_path):
    video_input = cv2.VideoCapture(video_input_path)
    success, frame = video_input.read()
    if success is False:
        decode_error = 'decode_error'
        w, h, fps = 0, 0, 0
    else:
        decode_error = None
        h, w = frame.shape[0:2]
        fps = video_input.get(cv2.CAP_PROP_FPS)
    video_input.release()
    return decode_error, fps, w, h

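Reviewer note: this probe only checks that the first frame decodes and reports fps and the original resolution; the pipeline below calls it the same way. A quick sketch (path and values are hypothetical):

decode_error, fps, w_ori, h_ori = video_process('input.mp4')  # hypothetical path
if decode_error is None:
    print(fps, w_ori, h_ori)  # e.g. 24.0, 1280, 720 for a decodable clip
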
class Stack(object):

    def __init__(self, roll=False):
        self.roll = roll

    def __call__(self, img_group):
        mode = img_group[0].mode
        if mode == '1':
            img_group = [img.convert('L') for img in img_group]
            mode = 'L'
        if mode == 'L':
            return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
        elif mode == 'RGB':
            if self.roll:
                return np.stack([np.array(x)[:, :, ::-1] for x in img_group],
                                axis=2)
            else:
                return np.stack(img_group, axis=2)
        else:
            raise NotImplementedError(f'Image mode {mode}')

class ToTorchFormatTensor(object):
    """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. """

    def __init__(self, div=True):
        self.div = div

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        img = img.float().div(255) if self.div else img.float()
        return img


_to_tensors = transforms.Compose([Stack(), ToTorchFormatTensor()])

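Reviewer note: the composed transform turns a list of T RGB PIL frames into a (T, C, H, W) float tensor in [0, 1]. A minimal shape check (illustrative only):

frames = [Image.new('RGB', (192, 96)) for _ in range(5)]  # (w, h) matches the module globals
feats = _to_tensors(frames)             # torch.Size([5, 3, 96, 192])
feats = feats.unsqueeze(0) * 2 - 1      # (1, T, C, H, W), rescaled to [-1, 1] as used below
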
def get_crop_mask_v1(mask):
    orig_h, orig_w, _ = mask.shape
    if (mask == 255).all():
        return mask, (0, int(orig_h), 0, int(orig_w)), \
            [0, int(orig_h), 0, int(orig_w)], [0, int(orig_h), 0, int(orig_w)]
    hs = np.min(np.where(mask == 0)[0])
    he = np.max(np.where(mask == 0)[0])
    ws = np.min(np.where(mask == 0)[1])
    we = np.max(np.where(mask == 0)[1])
    crop_box = [ws, hs, we, he]
    mask_h = round(int(orig_h / 2) / 4) * 4
    mask_w = round(int(orig_w / 2) / 4) * 4
    if (hs < mask_h) and (he < mask_h) and (ws < mask_w) and (we < mask_w):
        crop_mask = mask[:mask_h, :mask_w, :]
        res_pix = (0, mask_h, 0, mask_w)
    elif (hs < mask_h) and (he < mask_h) and (ws > mask_w) and (we > mask_w):
        crop_mask = mask[:mask_h, orig_w - mask_w:orig_w, :]
        res_pix = (0, mask_h, orig_w - mask_w, int(orig_w))
    elif (hs > mask_h) and (he > mask_h) and (ws < mask_w) and (we < mask_w):
        crop_mask = mask[orig_h - mask_h:orig_h, :mask_w, :]
        res_pix = (orig_h - mask_h, int(orig_h), 0, mask_w)
    elif (hs > mask_h) and (he > mask_h) and (ws > mask_w) and (we > mask_w):
        crop_mask = mask[orig_h - mask_h:orig_h, orig_w - mask_w:orig_w, :]
        res_pix = (orig_h - mask_h, int(orig_h), orig_w - mask_w, int(orig_w))
    elif (hs < mask_h) and (he < mask_h) and (ws < mask_w) and (we > mask_w):
        crop_mask = mask[:mask_h, :, :]
        res_pix = (0, mask_h, 0, int(orig_w))
    elif (hs < mask_h) and (he > mask_h) and (ws < mask_w) and (we < mask_w):
        crop_mask = mask[:, :mask_w, :]
        res_pix = (0, int(orig_h), 0, mask_w)
    elif (hs > mask_h) and (he > mask_h) and (ws < mask_w) and (we > mask_w):
        crop_mask = mask[orig_h - mask_h:orig_h, :, :]
        res_pix = (orig_h - mask_h, int(orig_h), 0, int(orig_w))
    elif (hs < mask_h) and (he > mask_h) and (ws > mask_w) and (we > mask_w):
        crop_mask = mask[:, orig_w - mask_w:orig_w, :]
        res_pix = (0, int(orig_h), orig_w - mask_w, int(orig_w))
    else:
        crop_mask = mask
        res_pix = (0, int(orig_h), 0, int(orig_w))
    a = ws - res_pix[2]
    b = hs - res_pix[0]
    c = we - res_pix[2]
    d = he - res_pix[0]
    return crop_mask, res_pix, crop_box, [a, b, c, d]

def get_ref_index(neighbor_ids, length):
    ref_index = []
    for i in range(0, length, ref_length):
        if i not in neighbor_ids:
            ref_index.append(i)
    return ref_index

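Reviewer note: since ref_length equals MAX_frame (300), each processed chunk contributes at most one non-local reference frame. For example:

# Illustrative only: a window centred at frame 40 of a 300-frame chunk.
print(get_ref_index(neighbor_ids=list(range(20, 61)), length=300))  # -> [0]
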
def read_mask_oneImage(mpath):
    masks = []
    print('mask_path: {}'.format(mpath))
    start = int(mpath.split('/')[-1].split('mask_')[1].split('_')[0])
    end = int(
        mpath.split('/')[-1].split('mask_')[1].split('_')[1].split('.')[0])
    m = Image.open(mpath)
    m = np.array(m.convert('L'))
    m = np.array(m > 0).astype(np.uint8)
    m = 1 - m
    for i in range(start - 1, end + 1):
        masks.append(Image.fromarray(m * 255))
    return masks

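Reviewer note: from the parsing above, mask files are expected to be named mask_<begin>_<end>.<ext>, and the single binarized, inverted mask is replicated for every frame in that span. A sketch with a hypothetical file name:

masks = read_mask_oneImage('mask_dir/mask_1_50.png')  # hypothetical path/name
print(len(masks))  # 51: range(start - 1, end + 1) yields end - start + 2 copies
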
def check_size(h, w):
    is_resize = False
    if h != 240:
        h = 240
        is_resize = True
    if w != 432:
        w = 432
        is_resize = True
    return is_resize

def get_mask_list(mask_path):
    mask_names = os.listdir(mask_path)
    mask_names.sort()
    abs_mask_path = []
    mask_list = []
    begin_list = []
    end_list = []
    for mask_name in mask_names:
        mask_name_tmp = mask_name.split('mask_')[1]
        begin_list.append(int(mask_name_tmp.split('_')[0]))
        end_list.append(int(mask_name_tmp.split('_')[1].split('.')[0]))
        abs_mask_path.append(os.path.join(mask_path, mask_name))
        mask = cv2.imread(os.path.join(mask_path, mask_name))
        mask_list.append(mask)
    return mask_list, begin_list, end_list, abs_mask_path

def inpainting_by_model_balance(model, video_inputPath, mask_path,
                                video_savePath, fps, w_ori, h_ori):
    video_ori = cv2.VideoCapture(video_inputPath)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_save = cv2.VideoWriter(video_savePath, fourcc, fps, (w_ori, h_ori))
    mask_list, begin_list, end_list, abs_mask_path = get_mask_list(mask_path)
    img_npy = []
    for index, mask in enumerate(mask_list):
        masks = read_mask_oneImage(abs_mask_path[index])
        mask, res_pix, crop_for_oriimg, crop_for_inpimg = get_crop_mask_v1(
            mask)
        mask_h, mask_w = mask.shape[0:2]
        is_resize = check_size(mask.shape[0], mask.shape[1])
        begin = begin_list[index]
        end = end_list[index]
        print('begin: {}'.format(begin))
        print('end: {}'.format(end))
        for i in range(begin, end + 1, MAX_frame):
            begin_time = time.time()
            if i + MAX_frame <= end:
                video_length = MAX_frame
            else:
                video_length = end - i + 1
            # read this chunk of original frames
            for frame_count in range(video_length):
                _, frame = video_ori.read()
                img_npy.append(frame)
            # crop around the mask region and resize to the 192x96 model input
            frames_temp = []
            for f in img_npy:
                f = Image.fromarray(f)
                i_temp = f.crop(
                    (res_pix[2], res_pix[0], res_pix[3], res_pix[1]))
                a = i_temp.resize((w, h), Image.NEAREST)
                frames_temp.append(a)
            feats_temp = _to_tensors(frames_temp).unsqueeze(0) * 2 - 1
            frames_temp = [np.array(f).astype(np.uint8) for f in frames_temp]
            masks_temp = []
            for m in masks[i - begin:i + video_length - begin]:
                m_temp = m.crop(
                    (res_pix[2], res_pix[0], res_pix[3], res_pix[1]))
                b = m_temp.resize((w, h), Image.NEAREST)
                masks_temp.append(b)
            binary_masks_temp = [
                np.expand_dims((np.array(m) != 0).astype(np.uint8), 2)
                for m in masks_temp
            ]
            masks_temp = _to_tensors(masks_temp).unsqueeze(0)
            feats_temp, masks_temp = feats_temp.cuda(), masks_temp.cuda()
            comp_frames = [None] * video_length
            model.eval()
            with torch.no_grad():
                feats_out = feats_temp * (1 - masks_temp).float()
                feats_out = feats_out.view(video_length, 3, h, w)
                feats_out = model.model.encoder(feats_out)
                _, c, feat_h, feat_w = feats_out.size()
                feats_out = feats_out.view(1, video_length, c, feat_h, feat_w)
            for f in range(0, video_length, neighbor_stride):
                neighbor_ids = [
                    i for i in range(
                        max(0, f - neighbor_stride),
                        min(video_length, f + neighbor_stride + 1))
                ]
                ref_ids = get_ref_index(neighbor_ids, video_length)
                with torch.no_grad():
                    pred_feat = model.model.infer(
                        feats_out[0, neighbor_ids + ref_ids, :, :, :],
                        masks_temp[0, neighbor_ids + ref_ids, :, :, :])
                    pred_img = torch.tanh(
                        model.model.decoder(
                            pred_feat[:len(neighbor_ids), :, :, :])).detach()
                    pred_img = (pred_img + 1) / 2
                    pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
                for j in range(len(neighbor_ids)):
                    idx = neighbor_ids[j]
                    img = np.array(pred_img[j]).astype(
                        np.uint8) * binary_masks_temp[idx] + frames_temp[
                            idx] * (1 - binary_masks_temp[idx])
                    if comp_frames[idx] is None:
                        comp_frames[idx] = img
                    else:
                        # blend predictions from overlapping windows 50/50
                        comp_frames[idx] = comp_frames[idx].astype(
                            np.float32) * 0.5 + img.astype(np.float32) * 0.5
            print('inpainting time:', time.time() - begin_time)
            # paste the inpainted crop back into the full-resolution frame and write it out
            for f in range(video_length):
                comp = np.array(comp_frames[f]).astype(
                    np.uint8) * binary_masks_temp[f] + frames_temp[f] * (
                        1 - binary_masks_temp[f])
                if is_resize:
                    comp = cv2.resize(comp, (mask_w, mask_h))
                complete_frame = img_npy[f]
                a1, b1, c1, d1 = crop_for_oriimg
                a2, b2, c2, d2 = crop_for_inpimg
                complete_frame[b1:d1, a1:c1] = comp[b2:d2, a2:c2]
                video_save.write(complete_frame)
            img_npy = []
    video_ori.release()

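Reviewer note: for intuition, the temporal sliding window above (neighbor_stride = 20) covers each frame up to three times, and repeated predictions are averaged 50/50 with the running composite. Illustrative layout for a 300-frame chunk:

# f = 0   -> neighbor_ids = [0..20],    ref_ids = []    (frame 0 already local)
# f = 20  -> neighbor_ids = [0..40],    ref_ids = []
# f = 40  -> neighbor_ids = [20..60],   ref_ids = [0]
# ...
# f = 280 -> neighbor_ids = [260..299], ref_ids = [0]
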
@@ -0,0 +1,373 @@
""" VideoInpaintingNetwork
Base modules are adapted from https://github.com/researchmm/STTN,
originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab.
"""
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm as _spectral_norm  # used by the spectral_norm() wrapper below

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

class BaseNetwork(nn.Module):

    def __init__(self):
        super(BaseNetwork, self).__init__()

    def print_network(self):
        if isinstance(self, list):
            self = self[0]
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print(
            'Network [%s] was created. Total number of parameters: %.1f million. '
            'To see the architecture, do print(network).' %
            (type(self).__name__, num_params / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        '''
        Initialize the network's weights.
        init_type: normal | xavier | kaiming | orthogonal
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        '''

        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('InstanceNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    nn.init.constant_(m.weight.data, 1.0)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1
                                           or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    nn.init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':
                    m.reset_parameters()
                else:
                    raise NotImplementedError(
                        'initialization method [%s] is not implemented'
                        % init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)

@MODELS.register_module(
    Tasks.video_inpainting, module_name=Models.video_inpainting)
class VideoInpainting(TorchModel):

    def __init__(self, model_dir, device_id=0, *args, **kwargs):
        super().__init__(
            model_dir=model_dir, device_id=device_id, *args, **kwargs)
        self.model = InpaintGenerator()
        pretrained_params = torch.load('{}/{}'.format(
            model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
        self.model.load_state_dict(pretrained_params['netG'])
        self.model.eval()
        self.device_id = device_id
        if self.device_id >= 0 and torch.cuda.is_available():
            self.model.to('cuda:{}'.format(self.device_id))
            logger.info('Use GPU: {}'.format(self.device_id))
        else:
            self.device_id = -1
            logger.info('Use CPU for inference')

class InpaintGenerator(BaseNetwork):

    def __init__(self, init_weights=True):
        super(InpaintGenerator, self).__init__()
        channel = 256
        stack_num = 6
        patchsize = [(48, 24), (16, 8), (8, 4), (4, 2)]
        blocks = []
        for _ in range(stack_num):
            blocks.append(TransformerBlock(patchsize, hidden=channel))
        self.transformer = nn.Sequential(*blocks)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, channel, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.decoder = nn.Sequential(
            deconv(channel, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            deconv(64, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
        if init_weights:
            self.init_weights()

    def forward(self, masked_frames, masks):
        b, t, c, h, w = masked_frames.size()
        masks = masks.view(b * t, 1, h, w)
        enc_feat = self.encoder(masked_frames.view(b * t, c, h, w))
        _, c, h, w = enc_feat.size()
        masks = F.interpolate(masks, scale_factor=1.0 / 4)
        enc_feat = self.transformer({
            'x': enc_feat,
            'm': masks,
            'b': b,
            'c': c
        })['x']
        output = self.decoder(enc_feat)
        output = torch.tanh(output)
        return output

    def infer(self, feat, masks):
        t, c, h, w = masks.size()
        masks = masks.view(t, c, h, w)
        masks = F.interpolate(masks, scale_factor=1.0 / 4)
        t, c, _, _ = feat.size()
        enc_feat = self.transformer({
            'x': feat,
            'm': masks,
            'b': 1,
            'c': c
        })['x']
        return enc_feat

class deconv(nn.Module):

    def __init__(self,
                 input_channel,
                 output_channel,
                 kernel_size=3,
                 padding=0):
        super().__init__()
        self.conv = nn.Conv2d(
            input_channel,
            output_channel,
            kernel_size=kernel_size,
            stride=1,
            padding=padding)

    def forward(self, x):
        x = F.interpolate(
            x, scale_factor=2, mode='bilinear', align_corners=True)
        x = self.conv(x)
        return x

class Attention(nn.Module):
    """
    Compute scaled dot-product attention.
    """

    def forward(self, query, key, value, m):
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
            query.size(-1))
        # NOTE: masked_fill is not in-place and its result is discarded here, so
        # as written this call has no effect; scores = scores.masked_fill(m, -1e9)
        # (or masked_fill_) would actually apply the mask.
        scores.masked_fill(m, -1e9)
        p_attn = F.softmax(scores, dim=-1)
        p_val = torch.matmul(p_attn, value)
        return p_val, p_attn

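Reviewer note: for reference, this is the standard scaled dot-product attention over patch tokens,

    Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

where m is intended to flag patch tokens that are mostly inside the masked (hole) region so they are suppressed as keys (see the note above about masked_fill).
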
class MultiHeadedAttention(nn.Module):
    """
    Multi-head attention where each head attends over a different
    spatio-temporal patch size; the channel dimension is split evenly
    across the heads.
    """

    def __init__(self, patchsize, d_model):
        super().__init__()
        self.patchsize = patchsize
        self.query_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0)
        self.value_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0)
        self.key_embedding = nn.Conv2d(
            d_model, d_model, kernel_size=1, padding=0)
        self.output_linear = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True))
        self.attention = Attention()

    def forward(self, x, m, b, c):
        bt, _, h, w = x.size()
        t = bt // b
        d_k = c // len(self.patchsize)
        output = []
        _query = self.query_embedding(x)
        _key = self.key_embedding(x)
        _value = self.value_embedding(x)
        for (width, height), query, key, value in zip(
                self.patchsize,
                torch.chunk(_query, len(self.patchsize), dim=1),
                torch.chunk(_key, len(self.patchsize), dim=1),
                torch.chunk(_value, len(self.patchsize), dim=1)):
            out_w, out_h = w // width, h // height
            mm = m.view(b, t, 1, out_h, height, out_w, width)
            mm = mm.permute(0, 1, 3, 5, 2, 4,
                            6).contiguous().view(b, t * out_h * out_w,
                                                 height * width)
            mm = (mm.mean(-1) > 0.5).unsqueeze(1).repeat(
                1, t * out_h * out_w, 1)
            query = query.view(b, t, d_k, out_h, height, out_w, width)
            query = query.permute(0, 1, 3, 5, 2, 4,
                                  6).contiguous().view(b, t * out_h * out_w,
                                                       d_k * height * width)
            key = key.view(b, t, d_k, out_h, height, out_w, width)
            key = key.permute(0, 1, 3, 5, 2, 4,
                              6).contiguous().view(b, t * out_h * out_w,
                                                   d_k * height * width)
            value = value.view(b, t, d_k, out_h, height, out_w, width)
            value = value.permute(0, 1, 3, 5, 2, 4,
                                  6).contiguous().view(b, t * out_h * out_w,
                                                       d_k * height * width)
            y, _ = self.attention(query, key, value, mm)
            y = y.view(b, t, out_h, out_w, d_k, height, width)
            y = y.permute(0, 1, 4, 2, 5, 3, 6).contiguous().view(bt, d_k, h, w)
            output.append(y)
        output = torch.cat(output, 1)
        x = self.output_linear(output)
        return x

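Reviewer note: for intuition (assuming the 192x96 working resolution from inpainting.py, which the encoder downsamples by 4 to a 48x24 feature map), each of the four heads attends across all frames at a different patch granularity:

# Illustrative patch-grid arithmetic for the default patchsize list.
feat_w, feat_h = 48, 24
for width, height in [(48, 24), (16, 8), (8, 4), (4, 2)]:
    print((feat_w // width) * (feat_h // height))  # 1, 9, 36, 144 patches per frame
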
class FeedForward(nn.Module):

    def __init__(self, d_model):
        super(FeedForward, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=2, dilation=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(d_model, d_model, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        x = self.conv(x)
        return x

class TransformerBlock(nn.Module):
    """
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """

    def __init__(self, patchsize, hidden=128):
        super().__init__()
        self.attention = MultiHeadedAttention(patchsize, d_model=hidden)
        self.feed_forward = FeedForward(hidden)

    def forward(self, x):
        x, m, b, c = x['x'], x['m'], x['b'], x['c']
        x = x + self.attention(x, m, b, c)
        x = x + self.feed_forward(x)
        return {'x': x, 'm': m, 'b': b, 'c': c}

class Discriminator(BaseNetwork):

    def __init__(self,
                 in_channels=3,
                 use_sigmoid=False,
                 use_spectral_norm=True,
                 init_weights=True):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid
        nf = 64
        self.conv = nn.Sequential(
            spectral_norm(
                nn.Conv3d(
                    in_channels=in_channels,
                    out_channels=nf * 1,
                    kernel_size=(3, 5, 5),
                    stride=(1, 2, 2),
                    padding=1,
                    bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(
                nn.Conv3d(
                    nf * 1,
                    nf * 2,
                    kernel_size=(3, 5, 5),
                    stride=(1, 2, 2),
                    padding=(1, 2, 2),
                    bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(
                nn.Conv3d(
                    nf * 2,
                    nf * 4,
                    kernel_size=(3, 5, 5),
                    stride=(1, 2, 2),
                    padding=(1, 2, 2),
                    bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(
                nn.Conv3d(
                    nf * 4,
                    nf * 4,
                    kernel_size=(3, 5, 5),
                    stride=(1, 2, 2),
                    padding=(1, 2, 2),
                    bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(
                nn.Conv3d(
                    nf * 4,
                    nf * 4,
                    kernel_size=(3, 5, 5),
                    stride=(1, 2, 2),
                    padding=(1, 2, 2),
                    bias=not use_spectral_norm), use_spectral_norm),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(
                nf * 4,
                nf * 4,
                kernel_size=(3, 5, 5),
                stride=(1, 2, 2),
                padding=(1, 2, 2)))
        if init_weights:
            self.init_weights()

    def forward(self, xs):
        xs_t = torch.transpose(xs, 0, 1)
        xs_t = xs_t.unsqueeze(0)
        feat = self.conv(xs_t)
        if self.use_sigmoid:
            feat = torch.sigmoid(feat)
        out = torch.transpose(feat, 1, 2)
        return out

def spectral_norm(module, mode=True):
    if mode:
        return _spectral_norm(module)
    return module

@@ -610,4 +610,9 @@ TASK_OUTPUTS = {
    #   {
    #       "img_embedding": np.array with shape [1, D],
    #   }
    Tasks.image_reid_person: [OutputKeys.IMG_EMBEDDING],

    #   {
    #       'output': 'Done' or 'decode_error'
    #   }
    Tasks.video_inpainting: [OutputKeys.OUTPUT]
}

@@ -168,6 +168,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                                     'damo/cv_resnet50-bert_video-scene-segmentation_movienet'),
    Tasks.shop_segmentation: (Pipelines.shop_segmentation,
                              'damo/cv_vitb16_segmentation_shop-seg'),
    Tasks.video_inpainting: (Pipelines.video_inpainting,
                             'damo/cv_video-inpainting'),
}

@@ -0,0 +1,47 @@
from typing import Any, Dict

from modelscope.metainfo import Pipelines
from modelscope.models.cv.video_inpainting import inpainting
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.video_inpainting, module_name=Pipelines.video_inpainting)
class VideoInpaintingPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        Use `model` to create a video inpainting pipeline for prediction.

        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        return input

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        decode_error, fps, w, h = inpainting.video_process(
            input['video_input_path'])
        if decode_error is not None:
            return {OutputKeys.OUTPUT: 'decode_error'}
        inpainting.inpainting_by_model_balance(self.model,
                                               input['video_input_path'],
                                               input['mask_path'],
                                               input['video_output_path'], fps,
                                               w, h)
        return {OutputKeys.OUTPUT: 'Done'}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs

@@ -70,6 +70,9 @@ class CVTasks(object):
    crowd_counting = 'crowd-counting'
    movie_scene_segmentation = 'movie-scene-segmentation'

    # video editing
    video_inpainting = 'video-inpainting'

    # reid and tracking
    video_single_object_tracking = 'video-single-object-tracking'
    video_summarization = 'video-summarization'

@@ -1,5 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import unittest

@@ -0,0 +1,39 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class VideoInpaintingTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model = 'damo/cv_video-inpainting'
        self.mask_dir = 'data/test/videos/mask_dir'
        self.video_in = 'data/test/videos/video_inpainting_test.mp4'
        self.video_out = 'out.mp4'
        self.input = {
            'video_input_path': self.video_in,
            'video_output_path': self.video_out,
            'mask_path': self.mask_dir
        }

    def pipeline_inference(self, pipeline: Pipeline, input: str):
        result = pipeline(input)
        print(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        video_inpainting = pipeline(Tasks.video_inpainting, model=self.model)
        self.pipeline_inference(video_inpainting, self.input)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        video_inpainting = pipeline(Tasks.video_inpainting)
        self.pipeline_inference(video_inpainting, self.input)


if __name__ == '__main__':
    unittest.main()