Browse Source

substitute face detection model in skin_retouching_pipeline.py

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10909902
master^2
ly261666 wenmeng.zwm 3 years ago
parent
commit
4208d51e23
1 changed files with 16 additions and 11 deletions
  1. +16
    -11
      modelscope/pipelines/cv/skin_retouching_pipeline.py

+ 16
- 11
modelscope/pipelines/cv/skin_retouching_pipeline.py View File

@@ -15,11 +15,10 @@ from modelscope.models.cv.skin_retouching.detection_model.detection_unet_in impo
DetectionUNet DetectionUNet
from modelscope.models.cv.skin_retouching.inpainting_model.inpainting_unet import \ from modelscope.models.cv.skin_retouching.inpainting_model.inpainting_unet import \
RetouchingNet RetouchingNet
from modelscope.models.cv.skin_retouching.retinaface.predict_single import \
Model
from modelscope.models.cv.skin_retouching.unet_deploy import UNet from modelscope.models.cv.skin_retouching.unet_deploy import UNet
from modelscope.models.cv.skin_retouching.utils import * # noqa F403 from modelscope.models.cv.skin_retouching.utils import * # noqa F403
from modelscope.outputs import OutputKeys from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage from modelscope.preprocessors import LoadImage
@@ -48,8 +47,6 @@ class SkinRetouchingPipeline(Pipeline):


device = create_device(self.device_name) device = create_device(self.device_name)
model_path = os.path.join(self.model, ModelFile.TORCH_MODEL_FILE) model_path = os.path.join(self.model, ModelFile.TORCH_MODEL_FILE)
detector_model_path = os.path.join(
self.model, 'retinaface_resnet50_2020-07-20_old_torch.pth')
local_model_path = os.path.join(self.model, 'joint_20210926.pth') local_model_path = os.path.join(self.model, 'joint_20210926.pth')
skin_model_path = os.path.join(self.model, ModelFile.TF_GRAPH_FILE) skin_model_path = os.path.join(self.model, ModelFile.TF_GRAPH_FILE)


@@ -58,10 +55,9 @@ class SkinRetouchingPipeline(Pipeline):
torch.load(model_path, map_location='cpu')['generator']) torch.load(model_path, map_location='cpu')['generator'])
self.generator.eval() self.generator.eval()


self.detector = Model(max_size=512, device=device)
state_dict = torch.load(detector_model_path, map_location='cpu')
self.detector.load_state_dict(state_dict)
self.detector.eval()
det_model_id = 'damo/cv_resnet50_face-detection_retinaface'
self.detector = pipeline(Tasks.face_detection, model=det_model_id)
self.detector.detector.to(device)


self.local_model_path = local_model_path self.local_model_path = local_model_path
ckpt_dict_load = torch.load(self.local_model_path, map_location='cpu') ckpt_dict_load = torch.load(self.local_model_path, map_location='cpu')
@@ -136,9 +132,18 @@ class SkinRetouchingPipeline(Pipeline):
(rgb_image.shape[0], rgb_image.shape[1], 3), (rgb_image.shape[0], rgb_image.shape[1], 3),
dtype=np.float32) * 0.5 dtype=np.float32) * 0.5


results = self.detector.predict_jsons(
rgb_image
) # list, [{'bbox':, [x1, y1, x2, y2], 'score'...}, ...]
det_results = self.detector(rgb_image)
# list, [{'bbox':, [x1, y1, x2, y2], 'score'...}, ...]
results = []
for i in range(len(det_results['scores'])):
info_dict = {}
info_dict['bbox'] = np.array(det_results['boxes'][i]).astype(
np.int32).tolist()
info_dict['score'] = det_results['scores'][i]
info_dict['landmarks'] = np.array(
det_results['keypoints'][i]).astype(np.int32).reshape(
5, 2).tolist()
results.append(info_dict)


crop_bboxes = get_crop_bbox(results) crop_bboxes = get_crop_bbox(results)




Loading…
Cancel
Save