diff --git a/modelscope/models/cv/video_inpainting/__init__.py b/modelscope/models/cv/video_inpainting/__init__.py index fd93fe3c..f5489da9 100644 --- a/modelscope/models/cv/video_inpainting/__init__.py +++ b/modelscope/models/cv/video_inpainting/__init__.py @@ -1,4 +1,4 @@ -# copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule diff --git a/modelscope/models/cv/video_inpainting/inpainting.py b/modelscope/models/cv/video_inpainting/inpainting.py index 9632e01c..e2af2ad0 100644 --- a/modelscope/models/cv/video_inpainting/inpainting.py +++ b/modelscope/models/cv/video_inpainting/inpainting.py @@ -1,6 +1,6 @@ """ VideoInpaintingProcess -Base modules are adapted from https://github.com/researchmm/STTN, -originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, +The implementation here is modified based on STTN, +originally Apache 2.0 License and publicly available at https://github.com/researchmm/STTN """ import os @@ -243,7 +243,8 @@ def inpainting_by_model_balance(model, video_inputPath, mask_path, for m in masks_temp ] masks_temp = _to_tensors(masks_temp).unsqueeze(0) - feats_temp, masks_temp = feats_temp.cuda(), masks_temp.cuda() + if torch.cuda.is_available(): + feats_temp, masks_temp = feats_temp.cuda(), masks_temp.cuda() comp_frames = [None] * video_length model.eval() with torch.no_grad(): diff --git a/modelscope/models/cv/video_inpainting/inpainting_model.py b/modelscope/models/cv/video_inpainting/inpainting_model.py index a791b0ab..ffecde67 100644 --- a/modelscope/models/cv/video_inpainting/inpainting_model.py +++ b/modelscope/models/cv/video_inpainting/inpainting_model.py @@ -1,15 +1,18 @@ -""" VideoInpaintingNetwork -Base modules are adapted from https://github.com/researchmm/STTN, -originally Apache 2.0 License, Copyright (c) 2018-2022 OpenMMLab, +""" VideoInpaintingNetwork
+The implementation here is modified based on STTN, +originally Apache 2.0 License and publicly available at https://github.com/researchmm/STTN """ import math +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +import torchvision.models as models from modelscope.metainfo import Models +from modelscope.models import Model from modelscope.models.base import TorchModel from modelscope.models.builder import MODELS from modelscope.utils.constant import ModelFile, Tasks @@ -84,8 +87,13 @@ class VideoInpainting(TorchModel): super().__init__( model_dir=model_dir, device_id=device_id, *args, **kwargs) self.model = InpaintGenerator() - pretrained_params = torch.load('{}/{}'.format( - model_dir, ModelFile.TORCH_MODEL_BIN_FILE)) + if torch.cuda.is_available(): + device = 'cuda' + else: + device = 'cpu' + pretrained_params = torch.load( + '{}/{}'.format(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), + map_location=device) self.model.load_state_dict(pretrained_params['netG']) self.model.eval() self.device_id = device_id