Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9744999master
| @@ -45,8 +45,8 @@ class FQA(object): | |||||
| model.load_state_dict(model_dict) | model.load_state_dict(model_dict) | ||||
| def get_face_quality(self, img): | def get_face_quality(self, img): | ||||
| img = torch.from_numpy(img).permute(2, 0, | |||||
| 1).unsqueeze(0).flip(1).cuda() | |||||
| img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).flip(1).to( | |||||
| self.device) | |||||
| img = (img - 127.5) / 128.0 | img = (img - 127.5) / 128.0 | ||||
| # extract features & predict quality | # extract features & predict quality | ||||
| @@ -36,7 +36,6 @@ class ImageColorizationPipeline(Pipeline): | |||||
| self.device = torch.device('cuda') | self.device = torch.device('cuda') | ||||
| else: | else: | ||||
| self.device = torch.device('cpu') | self.device = torch.device('cpu') | ||||
| self.size = 1024 | |||||
| self.orig_img = None | self.orig_img = None | ||||
| self.model_type = 'stable' | self.model_type = 'stable' | ||||
| @@ -91,6 +90,8 @@ class ImageColorizationPipeline(Pipeline): | |||||
| img = LoadImage.convert_to_img(input).convert('LA').convert('RGB') | img = LoadImage.convert_to_img(input).convert('LA').convert('RGB') | ||||
| self.wide, self.height = img.size | self.wide, self.height = img.size | ||||
| if self.wide * self.height < 100000: | |||||
| self.size = 256 | |||||
| self.orig_img = img.copy() | self.orig_img = img.copy() | ||||
| img = img.resize((self.size, self.size), resample=PIL.Image.BILINEAR) | img = img.resize((self.size, self.size), resample=PIL.Image.BILINEAR) | ||||
| @@ -58,7 +58,8 @@ class ImagePortraitEnhancementPipeline(Pipeline): | |||||
| gpen_model_path = f'{model}/{ModelFile.TORCH_MODEL_FILE}' | gpen_model_path = f'{model}/{ModelFile.TORCH_MODEL_FILE}' | ||||
| self.face_enhancer.load_state_dict( | self.face_enhancer.load_state_dict( | ||||
| torch.load(gpen_model_path), strict=True) | |||||
| torch.load(gpen_model_path, map_location=torch.device('cpu')), | |||||
| strict=True) | |||||
| logger.info('load face enhancer model done') | logger.info('load face enhancer model done') | ||||
| @@ -82,7 +83,9 @@ class ImagePortraitEnhancementPipeline(Pipeline): | |||||
| sr_model_path = f'{model}/super_resolution/realesrnet_x{self.scale}.pth' | sr_model_path = f'{model}/super_resolution/realesrnet_x{self.scale}.pth' | ||||
| self.sr_model.load_state_dict( | self.sr_model.load_state_dict( | ||||
| torch.load(sr_model_path)['params_ema'], strict=True) | |||||
| torch.load(sr_model_path, | |||||
| map_location=torch.device('cpu'))['params_ema'], | |||||
| strict=True) | |||||
| logger.info('load sr model done') | logger.info('load sr model done') | ||||