You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

lseg_model.py 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os.path as osp
  3. from typing import Any, Dict
  4. import json
  5. import numpy as np
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from PIL import Image
  10. from modelscope.metainfo import Models
  11. from modelscope.models.base import TorchModel
  12. from modelscope.models.builder import MODELS
  13. from modelscope.models.cv.text_driven_segmentation import \
  14. TextDrivenSegmentation
  15. from modelscope.outputs import OutputKeys
  16. from modelscope.preprocessors import LoadImage
  17. from modelscope.utils.constant import ModelFile, Tasks
  18. from modelscope.utils.logger import get_logger
  19. logger = get_logger()
  20. __all__ = ['TextDrivenSeg']
  21. @MODELS.register_module(
  22. Tasks.text_driven_segmentation,
  23. module_name=Models.text_driven_segmentation)
  24. class TextDrivenSeg(TorchModel):
  25. """ text driven segmentation model.
  26. """
  27. def __init__(self, model_dir, device_id=0, *args, **kwargs):
  28. super().__init__(
  29. model_dir=model_dir, device_id=device_id, *args, **kwargs)
  30. self.model = TextDrivenSegmentation(model_dir=model_dir)
  31. pretrained_params = torch.load('{}/{}'.format(
  32. model_dir, ModelFile.TORCH_MODEL_BIN_FILE))
  33. self.model.load_state_dict(pretrained_params)
  34. self.model.eval()
  35. if device_id >= 0 and torch.cuda.is_available():
  36. self.model.to('cuda:{}'.format(device_id))
  37. logger.info('Use GPU: {}'.format(device_id))
  38. else:
  39. device_id = -1
  40. logger.info('Use CPU for inference')
  41. self.device_id = device_id
  42. def preprocess(self, img, size=640):
  43. mean = [0.48145466, 0.4578275, 0.40821073]
  44. std = [0.26862954, 0.26130258, 0.27577711]
  45. h, w, c = img.shape
  46. max_hw = max(h, w)
  47. ratio = 1.0 * size / max_hw
  48. crop_h, crop_w = int(ratio * h), int(ratio * w)
  49. pil_img = Image.fromarray(img)
  50. pil_img = pil_img.resize((crop_w, crop_h), Image.BILINEAR)
  51. np_img = np.array(pil_img, dtype=np.float32) / 255.
  52. for j in range(3):
  53. np_img[:, :, j] = (np_img[:, :, j] - mean[j]) / std[j]
  54. img_pad = np.zeros((size, size, 3), dtype=np.float32)
  55. img_pad[:crop_h, :crop_w] = np_img
  56. img_pad = torch.from_numpy(img_pad).permute(2, 0,
  57. 1).unsqueeze(0).float()
  58. return img_pad, h, w, crop_h, crop_w
  59. def postprocess(self, tensors, crop_h, crop_w, ori_h, ori_w):
  60. output = np.clip(tensors * 255., a_min=0, a_max=255.)
  61. crop_output = np.array(output[:crop_h, :crop_w], dtype=np.uint8)
  62. pil_output = Image.fromarray(crop_output)
  63. pil_output = pil_output.resize((ori_w, ori_h), Image.BILINEAR)
  64. np_output = np.array(pil_output, dtype=np.uint8)
  65. np_output[np_output < 128] = 0
  66. np_output[np_output >= 128] = 255
  67. np_output = np.uint8(np_output)
  68. return np_output
  69. def forward(self, image, text):
  70. """
  71. image should be numpy array, dtype=np.uint8, shape: height*width*3
  72. """
  73. image_tensor, ori_h, ori_w, crop_h, crop_w = self.preprocess(
  74. image, size=640)
  75. pred = self.inference(image_tensor, text)
  76. msk = self.postprocess(pred, crop_h, crop_w, ori_h, ori_w, size=640)
  77. outputs = {OutputKeys.MASKS: msk}
  78. return outputs
  79. def inference(self, image, text):
  80. """
  81. image should be tensor, 1 * 3 * 640 * 640
  82. """
  83. with torch.no_grad():
  84. if self.device_id == -1:
  85. output = self.model(image)
  86. else:
  87. device = torch.device('cuda', self.device_id)
  88. output = self.model(image.to(device), [text])
  89. output = F.interpolate(output, size=(640, 640), mode='bilinear')
  90. output = F.softmax(output, dim=1)
  91. output = torch.argmax(output, dim=1)
  92. output = output[0]
  93. if self.device_id == -1:
  94. pred = output.data.numpy()
  95. else:
  96. pred = output.data.cpu().numpy()
  97. del output
  98. return pred