tingwei.gtw yingda.chen 3 years ago
parent
commit
4a9dfbf095
3 changed files with 18 additions and 9 deletions
  1. +1
    -1
      modelscope/models/cv/video_inpainting/__init__.py
  2. +4
    -3
      modelscope/models/cv/video_inpainting/inpainting.py
  3. +13
    -5
      modelscope/models/cv/video_inpainting/inpainting_model.py

+ 1
- 1
modelscope/models/cv/video_inpainting/__init__.py View File

@@ -1,4 +1,4 @@
# copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING


from modelscope.utils.import_utils import LazyImportModule from modelscope.utils.import_utils import LazyImportModule


+ 4
- 3
modelscope/models/cv/video_inpainting/inpainting.py View File

@@ -1,6 +1,6 @@
""" VideoInpaintingProcess """ VideoInpaintingProcess
Base modules are adapted from https://github.com/researchmm/STTN,
originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab,
The implementation here is modified based on STTN,
originally Apache 2.0 License and publicly avaialbe at https://github.com/researchmm/STTN
""" """


import os import os
@@ -243,7 +243,8 @@ def inpainting_by_model_balance(model, video_inputPath, mask_path,
for m in masks_temp for m in masks_temp
] ]
masks_temp = _to_tensors(masks_temp).unsqueeze(0) masks_temp = _to_tensors(masks_temp).unsqueeze(0)
feats_temp, masks_temp = feats_temp.cuda(), masks_temp.cuda()
if torch.cuda.is_available():
feats_temp, masks_temp = feats_temp.cuda(), masks_temp.cuda()
comp_frames = [None] * video_length comp_frames = [None] * video_length
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():


+ 13
- 5
modelscope/models/cv/video_inpainting/inpainting_model.py View File

@@ -1,15 +1,18 @@
""" VideoInpaintingNetwork
Base modules are adapted from https://github.com/researchmm/STTN,
originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab,
""" VideoInpaintingProcess
The implementation here is modified based on STTN,
originally Apache 2.0 License and publicly avaialbe at https://github.com/researchmm/STTN
""" """


import math import math


import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision.models as models


from modelscope.metainfo import Models from modelscope.metainfo import Models
from modelscope.models import Model
from modelscope.models.base import TorchModel from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.constant import ModelFile, Tasks
@@ -84,8 +87,13 @@ class VideoInpainting(TorchModel):
super().__init__( super().__init__(
model_dir=model_dir, device_id=device_id, *args, **kwargs) model_dir=model_dir, device_id=device_id, *args, **kwargs)
self.model = InpaintGenerator() self.model = InpaintGenerator()
pretrained_params = torch.load('{}/{}'.format(
model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
if torch.cuda.is_available():
device = 'cuda'
else:
device = 'cpu'
pretrained_params = torch.load(
'{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
map_location=device)
self.model.load_state_dict(pretrained_params['netG']) self.model.load_state_dict(pretrained_params['netG'])
self.model.eval() self.model.eval()
self.device_id = device_id self.device_id = device_id


Loading…
Cancel
Save