tingwei.gtw yingda.chen 3 years ago
parent
commit
4a9dfbf095
3 changed files with 18 additions and 9 deletions
  1. +1
    -1
      modelscope/models/cv/video_inpainting/__init__.py
  2. +4
    -3
      modelscope/models/cv/video_inpainting/inpainting.py
  3. +13
    -5
      modelscope/models/cv/video_inpainting/inpainting_model.py

+ 1
- 1
modelscope/models/cv/video_inpainting/__init__.py View File

@@ -1,4 +1,4 @@
# copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule


+ 4
- 3
modelscope/models/cv/video_inpainting/inpainting.py View File

@@ -1,6 +1,6 @@
""" VideoInpaintingProcess
Base modules are adapted from https://github.com/researchmm/STTN,
originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab,
The implementation here is modified based on STTN,
originally Apache 2.0 License and publicly avaialbe at https://github.com/researchmm/STTN
"""

import os
@@ -243,7 +243,8 @@ def inpainting_by_model_balance(model, video_inputPath, mask_path,
for m in masks_temp
]
masks_temp = _to_tensors(masks_temp).unsqueeze(0)
feats_temp, masks_temp = feats_temp.cuda(), masks_temp.cuda()
if torch.cuda.is_available():
feats_temp, masks_temp = feats_temp.cuda(), masks_temp.cuda()
comp_frames = [None] * video_length
model.eval()
with torch.no_grad():


+ 13
- 5
modelscope/models/cv/video_inpainting/inpainting_model.py View File

@@ -1,15 +1,18 @@
""" VideoInpaintingNetwork
Base modules are adapted from https://github.com/researchmm/STTN,
originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab,
""" VideoInpaintingProcess
The implementation here is modified based on STTN,
originally Apache 2.0 License and publicly avaialbe at https://github.com/researchmm/STTN
"""

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

from modelscope.metainfo import Models
from modelscope.models import Model
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
@@ -84,8 +87,13 @@ class VideoInpainting(TorchModel):
super().__init__(
model_dir=model_dir, device_id=device_id, *args, **kwargs)
self.model = InpaintGenerator()
pretrained_params = torch.load('{}/{}'.format(
model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
if torch.cuda.is_available():
device = 'cuda'
else:
device = 'cpu'
pretrained_params = torch.load(
'{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
map_location=device)
self.model.load_state_dict(pretrained_params['netG'])
self.model.eval()
self.device_id = device_id


Loading…
Cancel
Save