|
|
|
@@ -1,15 +1,18 @@ |
|
|
|
""" VideoInpaintingNetwork |
|
|
|
Base modules are adapted from https://github.com/researchmm/STTN, |
|
|
|
originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, |
|
|
|
""" VideoInpaintingProcess |
|
|
|
The implementation here is modified based on STTN, |
|
|
|
originally Apache 2.0 License and publicly avaialbe at https://github.com/researchmm/STTN |
|
|
|
""" |
|
|
|
|
|
|
|
import math |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
import torch.nn as nn |
|
|
|
import torch.nn.functional as F |
|
|
|
import torchvision.models as models |
|
|
|
|
|
|
|
from modelscope.metainfo import Models |
|
|
|
from modelscope.models import Model |
|
|
|
from modelscope.models.base import TorchModel |
|
|
|
from modelscope.models.builder import MODELS |
|
|
|
from modelscope.utils.constant import ModelFile, Tasks |
|
|
|
@@ -84,8 +87,13 @@ class VideoInpainting(TorchModel): |
|
|
|
super().__init__( |
|
|
|
model_dir=model_dir, device_id=device_id, *args, **kwargs) |
|
|
|
self.model = InpaintGenerator() |
|
|
|
pretrained_params = torch.load('{}/{}'.format( |
|
|
|
model_dir, ModelFile.TORCH_MODEL_BIN_FILE)) |
|
|
|
if torch.cuda.is_available(): |
|
|
|
device = 'cuda' |
|
|
|
else: |
|
|
|
device = 'cpu' |
|
|
|
pretrained_params = torch.load( |
|
|
|
'{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), |
|
|
|
map_location=device) |
|
|
|
self.model.load_state_dict(pretrained_params['netG']) |
|
|
|
self.model.eval() |
|
|
|
self.device_id = device_id |
|
|
|
|