|
|
@@ -12,7 +12,7 @@ from modelscope.models.cv.image_colorization import (DynamicUnetDeep, |
|
|
from modelscope.outputs import OutputKeys |
|
|
from modelscope.outputs import OutputKeys |
|
|
from modelscope.pipelines.base import Input, Pipeline |
|
|
from modelscope.pipelines.base import Input, Pipeline |
|
|
from modelscope.pipelines.builder import PIPELINES |
|
|
from modelscope.pipelines.builder import PIPELINES |
|
|
from modelscope.preprocessors import load_image |
|
|
|
|
|
|
|
|
from modelscope.preprocessors import LoadImage, load_image |
|
|
from modelscope.utils.constant import ModelFile, Tasks |
|
|
from modelscope.utils.constant import ModelFile, Tasks |
|
|
from modelscope.utils.logger import get_logger |
|
|
from modelscope.utils.logger import get_logger |
|
|
|
|
|
|
|
|
@@ -31,7 +31,13 @@ class ImageColorizationPipeline(Pipeline): |
|
|
""" |
|
|
""" |
|
|
super().__init__(model=model, **kwargs) |
|
|
super().__init__(model=model, **kwargs) |
|
|
self.cut = 8 |
|
|
self.cut = 8 |
|
|
self.size = 1024 if self.device_name == 'cpu' else 512 |
|
|
|
|
|
|
|
|
self.size = 512 |
|
|
|
|
|
if torch.cuda.is_available(): |
|
|
|
|
|
self.device = torch.device('cuda') |
|
|
|
|
|
else: |
|
|
|
|
|
self.device = torch.device('cpu') |
|
|
|
|
|
self.size = 1024 |
|
|
|
|
|
|
|
|
self.orig_img = None |
|
|
self.orig_img = None |
|
|
self.model_type = 'stable' |
|
|
self.model_type = 'stable' |
|
|
self.norm = transforms.Compose([ |
|
|
self.norm = transforms.Compose([ |
|
|
@@ -82,18 +88,7 @@ class ImageColorizationPipeline(Pipeline): |
|
|
logger.info('load model done') |
|
|
logger.info('load model done') |
|
|
|
|
|
|
|
|
def preprocess(self, input: Input) -> Dict[str, Any]: |
|
|
def preprocess(self, input: Input) -> Dict[str, Any]: |
|
|
if isinstance(input, str): |
|
|
|
|
|
img = load_image(input).convert('LA').convert('RGB') |
|
|
|
|
|
elif isinstance(input, Image.Image): |
|
|
|
|
|
img = input.convert('LA').convert('RGB') |
|
|
|
|
|
elif isinstance(input, np.ndarray): |
|
|
|
|
|
if len(input.shape) == 2: |
|
|
|
|
|
input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) |
|
|
|
|
|
img = input[:, :, ::-1] # in rgb order |
|
|
|
|
|
img = PIL.Image.fromarray(img).convert('LA').convert('RGB') |
|
|
|
|
|
else: |
|
|
|
|
|
raise TypeError(f'input should be either str, PIL.Image,' |
|
|
|
|
|
f' np.array, but got {type(input)}') |
|
|
|
|
|
|
|
|
img = LoadImage.convert_to_img(input).convert('LA').convert('RGB') |
|
|
|
|
|
|
|
|
self.wide, self.height = img.size |
|
|
self.wide, self.height = img.size |
|
|
if self.wide * self.height > self.size * self.size: |
|
|
if self.wide * self.height > self.size * self.size: |
|
|
|