Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9515599master
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:08691a9373aa6d05b236a4ba788f3eccdea4c37aa77b30fc94b02ec3e1f18210
size 367017
@@ -16,6 +16,7 @@ class Models(object):
    nafnet = 'nafnet'
    csrnet = 'csrnet'
    cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin'
+    product_retrieval_embedding = 'product-retrieval-embedding'

    # nlp models
    bert = 'bert'
@@ -84,6 +85,7 @@ class Pipelines(object):
    image_super_resolution = 'rrdb-image-super-resolution'
    face_image_generation = 'gan-face-image-generation'
    style_transfer = 'AAMS-style-transfer'
+    product_retrieval_embedding = 'resnet50-product-retrieval-embedding'
    face_recognition = 'ir101-face-recognition-cfglint'
    image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
    image2image_translation = 'image-to-image-translation'
@@ -3,5 +3,5 @@ from . import (action_recognition, animal_recognition, cartoon,
               cmdssl_video_embedding, face_detection, face_generation,
               image_classification, image_color_enhance, image_colorization,
               image_denoise, image_instance_segmentation,
-               image_to_image_translation, object_detection, super_resolution,
-               virual_tryon)
+               image_to_image_translation, object_detection,
+               product_retrieval_embedding, super_resolution, virual_tryon)
@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .item_model import ProductRetrievalEmbedding

else:
    _import_structure = {
        'item_model': ['ProductRetrievalEmbedding'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
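Reviewer note, not part of the diff: a minimal sketch of how this lazy package behaves, assuming LazyImportModule follows the same deferred-import pattern as the other modelscope cv packages — the heavy item_model module is only imported on first attribute access.

    # sketch only; names come from the _import_structure above
    from modelscope.models.cv import product_retrieval_embedding as pre
    model_cls = pre.ProductRetrievalEmbedding  # .item_model is imported here, lazily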
@@ -0,0 +1,517 @@
import cv2
import numpy as np


class YOLOXONNX(object):
    """Product detection model with ONNX inference."""

    def __init__(self, onnx_path, multi_detect=False):
        """Create the product detection model.

        Args:
            onnx_path: path to the ONNX model used for product detection
            multi_detect: detection mode flag; should be set to False
        """
        self.input_reso = 416
        self.iou_thr = 0.45
        self.score_thr = 0.3
        self.img_shape = (self.input_reso, self.input_reso, 3)
        self.num_classes = 13
        self.onnx_path = onnx_path
        import onnxruntime as ort
        self.ort_session = ort.InferenceSession(self.onnx_path)
        self.with_p6 = False
        self.multi_detect = multi_detect
    def format_judge(self, img):
        """Downscale images larger than 1024x1024 px: prefer capping the long
        side at 1024, but never let the short side drop below 100 px."""
        m_min_width = 100
        m_min_height = 100
        height, width, c = img.shape
        if width * height > 1024 * 1024:
            if height > width:
                long_side = height
                short_side = width
                long_ratio = float(long_side) / 1024.0
                short_ratio = float(short_side) / float(m_min_width)
            else:
                long_side = width
                short_side = height
                long_ratio = float(long_side) / 1024.0
                short_ratio = float(short_side) / float(m_min_height)
            if long_side == height:
                if long_ratio < short_ratio:
                    height_new = 1024
                    width_new = int((1024 * width) / height)
                    img_res = cv2.resize(img, (width_new, height_new),
                                         interpolation=cv2.INTER_LINEAR)
                else:
                    height_new = int((m_min_width * height) / width)
                    width_new = m_min_width
                    img_res = cv2.resize(img, (width_new, height_new),
                                         interpolation=cv2.INTER_LINEAR)
            elif long_side == width:
                if long_ratio < short_ratio:
                    height_new = int((1024 * height) / width)
                    width_new = 1024
                    img_res = cv2.resize(img, (width_new, height_new),
                                         interpolation=cv2.INTER_LINEAR)
                else:
                    width_new = int((m_min_height * width) / height)
                    height_new = m_min_height
                    img_res = cv2.resize(img, (width_new, height_new),
                                         interpolation=cv2.INTER_LINEAR)
        else:
            img_res = img
        return img_res
    def preprocess(self, image, input_size, swap=(2, 0, 1)):
        """Letterbox-resize the image to the model input size.

        Args:
            image: cv2 image in BGR format
            input_size: model input size (height, width)
            swap: axis order for the output array (HWC -> CHW by default)
        """
        if len(image.shape) == 3:
            padded_img = np.ones((input_size[0], input_size[1], 3)) * 114.0
        else:
            padded_img = np.ones(input_size) * 114.0
        img = np.array(image)
        r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
        resized_img = cv2.resize(
            img,
            (int(img.shape[1] * r), int(img.shape[0] * r)),
            interpolation=cv2.INTER_LINEAR,
        ).astype(np.float32)
        padded_img[:int(img.shape[0] * r), :int(img.shape[1] * r)] = resized_img
        padded_img = padded_img.transpose(swap)
        padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
        return padded_img, r
    def cal_iou(self, val1, val2):
        """Compute IoU between two boxes given as (x1, y1, x2, y2)."""
        x11, y11, x12, y12 = val1
        x21, y21, x22, y22 = val2
        leftX = max(x11, x21)
        topY = max(y11, y21)
        rightX = min(x12, x22)
        bottomY = min(y12, y22)
        if rightX < leftX or bottomY < topY:
            return 0
        area = float((rightX - leftX) * (bottomY - topY))
        barea = (x12 - x11) * (y12 - y11) + (x22 - x21) * (y22 - y21) - area
        if barea <= 0:
            return 0
        return area / barea
    def nms(self, boxes, scores, nms_thr):
        """Single-class NMS implemented in NumPy."""
        x1 = boxes[:, 0]
        y1 = boxes[:, 1]
        x2 = boxes[:, 2]
        y2 = boxes[:, 3]
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        order = scores.argsort()[::-1]
        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            xx1 = np.maximum(x1[i], x1[order[1:]])
            yy1 = np.maximum(y1[i], y1[order[1:]])
            xx2 = np.minimum(x2[i], x2[order[1:]])
            yy2 = np.minimum(y2[i], y2[order[1:]])
            w = np.maximum(0.0, xx2 - xx1 + 1)
            h = np.maximum(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (areas[i] + areas[order[1:]] - inter)
            inds = np.where(ovr <= nms_thr)[0]
            order = order[inds + 1]
        return keep
    def multiclass_nms(self, boxes, scores, nms_thr, score_thr):
        """Multi-class NMS implemented in NumPy."""
        final_dets = []
        num_classes = scores.shape[1]
        for cls_ind in range(num_classes):
            cls_scores = scores[:, cls_ind]
            valid_score_mask = cls_scores > score_thr
            if valid_score_mask.sum() == 0:
                continue
            else:
                valid_scores = cls_scores[valid_score_mask]
                valid_boxes = boxes[valid_score_mask]
                keep = self.nms(valid_boxes, valid_scores, nms_thr)
                if len(keep) > 0:
                    cls_inds = np.ones((len(keep), 1)) * cls_ind
                    dets = np.concatenate([
                        valid_boxes[keep], valid_scores[keep, None], cls_inds
                    ], 1)
                    final_dets.append(dets)
        if len(final_dets) == 0:
            return None
        return np.concatenate(final_dets, 0)
    def postprocess(self, outputs, img_size, p6=False):
        """Decode raw YOLOX grid outputs into absolute center/size boxes."""
        grids = []
        expanded_strides = []
        if not p6:
            strides = [8, 16, 32]
        else:
            strides = [8, 16, 32, 64]
        hsizes = [img_size[0] // stride for stride in strides]
        wsizes = [img_size[1] // stride for stride in strides]
        for hsize, wsize, stride in zip(hsizes, wsizes, strides):
            xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
            grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
            grids.append(grid)
            shape = grid.shape[:2]
            expanded_strides.append(np.full((*shape, 1), stride))
        grids = np.concatenate(grids, 1)
        expanded_strides = np.concatenate(expanded_strides, 1)
        outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
        outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
        return outputs
    def get_new_box_order(self, bboxes, labels, img_h, img_w):
        """Sort boxes by score, drop overlapping duplicates (IoU >= 0.45),
        and keep roughly the boxes scoring above 0.3 (at least one)."""
        bboxes = np.hstack((bboxes, np.zeros((bboxes.shape[0], 1))))
        scores = bboxes[:, 4]
        order = scores.argsort()[::-1]
        bboxes_temp = bboxes[order]
        labels_temp = labels[order]
        bboxes = np.empty((0, 6))
        bboxes = np.vstack((bboxes, bboxes_temp[0].tolist()))
        labels = np.empty((0, ))
        labels = np.hstack((labels, [labels_temp[0]]))
        for i in range(1, bboxes_temp.shape[0]):
            iou_max = 0
            for j in range(bboxes.shape[0]):
                iou_temp = self.cal_iou(bboxes_temp[i][:4], bboxes[j][:4])
                if iou_temp > iou_max:
                    iou_max = iou_temp
            if iou_max < 0.45:
                bboxes = np.vstack((bboxes, bboxes_temp[i].tolist()))
                labels = np.hstack((labels, [labels_temp[i]]))
        num_03 = (scores > 0.3).sum()
        num_out = max(num_03, 1)
        bboxes = bboxes[:num_out, :]
        labels = labels[:num_out]
        return bboxes, labels
    def forward(self, img_input, cid='0', sub_class=False):
        """Run product detection on the input image.

        Args:
            img_input: cv2 image in BGR format
            cid: category id (as a string) used to filter detections
            sub_class: whether cid refers to the fine-grained category set
        """
        input_shape = self.img_shape
        img, ratio = self.preprocess(img_input, input_shape)
        img_h, img_w = img_input.shape[:2]
        ort_inputs = {
            self.ort_session.get_inputs()[0].name: img[None, :, :, :]
        }
        output = self.ort_session.run(None, ort_inputs)
        predictions = self.postprocess(output[0], input_shape, self.with_p6)[0]
        boxes = predictions[:, :4]
        scores = predictions[:, 4:5] * predictions[:, 5:]
        boxes_xyxy = np.ones_like(boxes)
        boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2.
        boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2.
        boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2.
        boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2.
        boxes_xyxy /= ratio
        dets = self.multiclass_nms(
            boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
        if dets is None:
            # no detection: fall back to the whole image
            top1_bbox_str = str(0) + ',' + str(img_w) + ',' + str(
                0) + ',' + str(img_h)
            crop_img = img_input.copy()
            coord = top1_bbox_str
        else:
            bboxes = dets[:, :5]
            labels = dets[:, 5]
            if not self.multi_detect:
                cid = int(cid)
                if not sub_class:
                    if cid > -1:
                        if cid == 0:  # cloth
                            cid_ind1 = np.where(labels < 3)
                            cid_ind2 = np.where(labels == 9)
                            cid_ind = np.hstack((cid_ind1[0], cid_ind2[0]))
                            scores = bboxes[cid_ind, -1]  # labels 0, 1, 2, 9
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 3:  # bag
                            cid_ind = np.where(labels == 3)
                            scores = bboxes[cid_ind, -1]  # label 3
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 4:  # shoe
                            cid_ind = np.where(labels == 4)
                            scores = bboxes[cid_ind, -1]  # label 4
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        else:  # other
                            cid_ind5 = np.where(labels == 5)
                            cid_ind6 = np.where(labels == 6)
                            cid_ind7 = np.where(labels == 7)
                            cid_ind8 = np.where(labels == 8)
                            cid_ind10 = np.where(labels == 10)
                            cid_ind11 = np.where(labels == 11)
                            cid_ind12 = np.where(labels == 12)
                            cid_ind = np.hstack(
                                (cid_ind5[0], cid_ind6[0], cid_ind7[0],
                                 cid_ind8[0], cid_ind10[0], cid_ind11[0],
                                 cid_ind12[0]))
                            scores = bboxes[cid_ind, -1]  # labels 5-8, 10-12
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                    else:
                        bboxes, labels = self.get_new_box_order(
                            bboxes, labels, img_h, img_w)
                else:
                    if cid > -1:
                        if cid == 0:  # upper
                            cid_ind = np.where(labels == 0)
                            scores = bboxes[cid_ind, -1]  # label 0
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 1:  # skirt
                            cid_ind = np.where(labels == 1)
                            scores = bboxes[cid_ind, -1]  # label 1
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 2:  # lower
                            cid_ind = np.where(labels == 2)
                            scores = bboxes[cid_ind, -1]  # label 2
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 3:  # bag
                            cid_ind = np.where(labels == 3)
                            scores = bboxes[cid_ind, -1]  # label 3
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 4:  # shoe
                            cid_ind = np.where(labels == 4)
                            scores = bboxes[cid_ind, -1]  # label 4
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 5:  # access
                            cid_ind = np.where(labels == 5)
                            scores = bboxes[cid_ind, -1]  # label 5
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 7:  # beauty
                            cid_ind = np.where(labels == 6)
                            scores = bboxes[cid_ind, -1]  # label 6
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 9:  # furniture
                            cid_ind = np.where(labels == 8)
                            scores = bboxes[cid_ind, -1]  # label 8
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 21:  # underwear
                            cid_ind = np.where(labels == 9)
                            scores = bboxes[cid_ind, -1]  # label 9
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        elif cid == 22:  # digital
                            cid_ind = np.where(labels == 11)
                            scores = bboxes[cid_ind, -1]  # label 11
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                        else:  # other
                            cid_ind5 = np.where(labels == 7)  # bottle
                            cid_ind6 = np.where(labels == 10)  # toy
                            cid_ind7 = np.where(labels == 12)  # toy
                            cid_ind = np.hstack(
                                (cid_ind5[0], cid_ind6[0], cid_ind7[0]))
                            scores = bboxes[cid_ind, -1]  # labels 7, 10, 12
                            if scores.size > 0:
                                bboxes = bboxes[cid_ind]
                                labels = labels[cid_ind]
                            bboxes, labels = self.get_new_box_order(
                                bboxes, labels, img_h, img_w)
                    else:
                        bboxes, labels = self.get_new_box_order(
                            bboxes, labels, img_h, img_w)
            else:
                bboxes, labels = self.get_new_box_order(
                    bboxes, labels, img_h, img_w)
            top1_bbox = bboxes[0].astype(np.int32)
            top1_bbox[0] = min(max(0, top1_bbox[0]), img_input.shape[1] - 1)
            top1_bbox[1] = min(max(0, top1_bbox[1]), img_input.shape[0] - 1)
            top1_bbox[2] = max(min(img_input.shape[1] - 1, top1_bbox[2]), 0)
            top1_bbox[3] = max(min(img_input.shape[0] - 1, top1_bbox[3]), 0)
            if not self.multi_detect:
                top1_bbox_str = str(top1_bbox[0]) + ',' + str(
                    top1_bbox[2]) + ',' + str(top1_bbox[1]) + ',' + str(
                        top1_bbox[3])  # x1, x2, y1, y2
                crop_img = img_input[top1_bbox[1]:top1_bbox[3],
                                     top1_bbox[0]:top1_bbox[2], :]
                coord = top1_bbox_str
                coord = ''
                for i in range(0, len(bboxes)):
                    top_bbox = bboxes[i].astype(np.int32)
                    top_bbox[0] = min(
                        max(0, top_bbox[0]), img_input.shape[1] - 1)
                    top_bbox[1] = min(
                        max(0, top_bbox[1]), img_input.shape[0] - 1)
                    top_bbox[2] = max(
                        min(img_input.shape[1] - 1, top_bbox[2]), 0)
                    top_bbox[3] = max(
                        min(img_input.shape[0] - 1, top_bbox[3]), 0)
                    coord = coord + str(top_bbox[0]) + ',' + str(
                        top_bbox[2]) + ',' + str(top_bbox[1]) + ',' + str(
                            top_bbox[3]) + ',' + str(bboxes[i][4]) + ',' + str(
                                bboxes[i][5]) + ';'
            else:
                coord = ''
                for i in range(0, len(bboxes)):
                    top_bbox = bboxes[i].astype(np.int32)
                    top_bbox[0] = min(
                        max(0, top_bbox[0]), img_input.shape[1] - 1)
                    top_bbox[1] = min(
                        max(0, top_bbox[1]), img_input.shape[0] - 1)
                    top_bbox[2] = max(
                        min(img_input.shape[1] - 1, top_bbox[2]), 0)
                    top_bbox[3] = max(
                        min(img_input.shape[0] - 1, top_bbox[3]), 0)
                    coord = coord + str(top_bbox[0]) + ',' + str(
                        top_bbox[2]) + ',' + str(top_bbox[1]) + ',' + str(
                            top_bbox[3]) + ',' + str(bboxes[i][4]) + ',' + str(
                                bboxes[i][5]) + ';'  # x1, x2, y1, y2, conf
                crop_img = img_input[top1_bbox[1]:top1_bbox[3],
                                     top1_bbox[0]:top1_bbox[2], :]
        crop_img = cv2.resize(crop_img, (224, 224))
        return coord, crop_img  # return top1 image and coord
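Reviewer note, not part of the diff: a hedged usage sketch for the detector above (file paths are illustrative). forward returns the box string plus a 224x224 BGR crop of the top-1 product box; cid '3' selects the bag category, matching how item_model.py calls it.

    import cv2

    det = YOLOXONNX(onnx_path='onnx_detection.onnx', multi_detect=False)
    image = cv2.imread('product.jpg')          # BGR image
    coord, crop = det.forward(image, cid='3')  # bag category
    print(coord, crop.shape)                   # roughly 'x1,x2,y1,y2,score,0.0;...' and (224, 224, 3)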
@@ -0,0 +1,157 @@
import os
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def gn_init(m, zero_init=False):
    """Initialize a GroupNorm layer, optionally zero-initializing its scale."""
    assert isinstance(m, nn.GroupNorm)
    m.weight.data.fill_(0. if zero_init else 1.)
    m.bias.data.zero_()
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """Bottleneck block for ResNet-style networks.

        Args:
            inplanes: number of input channels
            planes: number of intermediate channels (the output has planes * 4)
        """
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.GroupNorm(32, planes)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = nn.GroupNorm(32, planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.GroupNorm(32, planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        gn_init(self.bn1)
        gn_init(self.bn2)
        gn_init(self.bn3, zero_init=True)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
class ResNet(nn.Module):
    """ResNet-style network with group normalization."""

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.GroupNorm(32, 64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1)
        self.gap = nn.AvgPool2d((14, 14))
        self.reduce_conv = nn.Conv2d(2048, 512, kernel_size=1)
        gn_init(self.bn1)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.AvgPool2d(stride, stride),
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=1,
                    bias=False),
                nn.GroupNorm(32, planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.gap(x)
        x = self.reduce_conv(x)  # reduce to 512 channels
        x = x.view(x.size(0), -1)  # (N, 512)
        return F.normalize(x, p=2, dim=1)
def preprocess(img):
    """Convert a cv2 BGR image into a normalized 1 x 3 x 224 x 224 float array."""
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_size = 224
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_new = cv2.resize(
        img, (img_size, img_size), interpolation=cv2.INTER_LINEAR)
    content = np.array(img_new).astype(np.float32)
    content = (content / 255.0 - mean) / std
    # HWC -> CHW, then add the batch dimension
    img_new = content.transpose(2, 0, 1)
    img_new = img_new[np.newaxis, :, :, :]
    return img_new


def resnet50_embed():
    """Create the ResNet50-style embedding network with group normalization."""
    net = ResNet(Bottleneck, [3, 4, 6, 3])
    return net
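Reviewer note, not part of the diff: a minimal sketch of how these helpers combine into a 512-d L2-normalized embedding. The network here has random weights; the released checkpoint is loaded by ProductRetrievalEmbedding in item_model.py below.

    import cv2
    import numpy as np
    import torch

    net = resnet50_embed().eval()
    crop = cv2.imread('crop.jpg')  # BGR crop of the product region (path illustrative)
    inp = torch.from_numpy(preprocess(crop).astype(np.float32))  # 1 x 3 x 224 x 224
    with torch.no_grad():
        feat = net(inp)  # shape (1, 512), unit L2 norm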
@@ -0,0 +1,115 @@
import os.path as osp
from typing import Any, Dict

import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.product_retrieval_embedding.item_detection import \
    YOLOXONNX
from modelscope.models.cv.product_retrieval_embedding.item_embedding import (
    preprocess, resnet50_embed)
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import create_device

logger = get_logger()

__all__ = ['ProductRetrievalEmbedding']
@MODELS.register_module(
    Tasks.product_retrieval_embedding,
    module_name=Models.product_retrieval_embedding)
class ProductRetrievalEmbedding(TorchModel):

    def __init__(self, model_dir, device='cpu', **kwargs):
        super().__init__(model_dir=model_dir, device=device, **kwargs)

        def filter_param(src_params, own_state):
            copied_keys = []
            for name, param in src_params.items():
                if name.startswith('module.'):
                    name = name[7:]
                if '.module.' not in list(own_state.keys())[0]:
                    name = name.replace('.module.', '.')
                if (name in own_state) and (own_state[name].shape
                                            == param.shape):
                    own_state[name].copy_(param)
                    copied_keys.append(name)

        def load_pretrained(model, src_params):
            if 'state_dict' in src_params:
                src_params = src_params['state_dict']
            own_state = model.state_dict()
            filter_param(src_params, own_state)
            model.load_state_dict(own_state)

        cpu_flag = device == 'cpu'
        self.device = create_device(
            cpu_flag)  # device.type is either 'cpu' or 'cuda'
        self.use_gpu = self.device.type == 'cuda'

        # config the model path
        self.local_model_dir = model_dir

        # init the embedding model
        self.preprocess_for_embed = preprocess  # input is a cv2 BGR image
        model_feat = resnet50_embed()
        src_params = torch.load(
            osp.join(self.local_model_dir, ModelFile.TORCH_MODEL_BIN_FILE),
            'cpu')
        load_pretrained(model_feat, src_params)
        if self.use_gpu:
            model_feat.to(self.device)
            logger.info('Use GPU: {}'.format(self.device))
        else:
            logger.info('Use CPU for inference')
        self.model_feat = model_feat

        # init the detection model
        self.model_det = YOLOXONNX(
            onnx_path=osp.join(self.local_model_dir, 'onnx_detection.onnx'),
            multi_detect=False)
        logger.info('load model done')
    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Detect the product region in the input image and extract its embedding."""
        # input should be a cv2 BGR image under the key 'img'
        assert 'img' in input.keys()

        def set_phase(model, is_train):
            if is_train:
                model.train()
            else:
                model.eval()

        is_train = False
        set_phase(self.model_feat, is_train)
        img = input['img']  # for detection
        cid = '3'  # detection category: bag
        # transform img (tensor) to a numpy array in BGR order
        if isinstance(img, torch.Tensor):
            img = img.data.cpu().numpy()
        res, crop_img = self.model_det.forward(img,
                                               cid)  # detect with bag category
        crop_img = self.preprocess_for_embed(crop_img)  # feature preprocess
        input_tensor = torch.from_numpy(crop_img.astype(np.float32))
        device = next(self.model_feat.parameters()).device
        use_gpu = device.type == 'cuda'
        with torch.no_grad():
            if use_gpu:
                input_tensor = input_tensor.to(device)
            out_embedding = self.model_feat(input_tensor)
            out_embedding = out_embedding.cpu().numpy()[
                0, :]  # feature vector with 512 elements
        output = {OutputKeys.IMG_EMBEDDING: out_embedding}
        return output
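Reviewer note, not part of the diff: a hedged sketch of driving the model directly, assuming TorchModel's __call__ dispatches to forward; the model id is the default added to the pipeline builder below, and the input is a dict holding a BGR ndarray under 'img'.

    import cv2
    from modelscope.models import Model
    from modelscope.outputs import OutputKeys

    model = Model.from_pretrained('damo/cv_resnet50_product-bag-embedding-models')
    embedding = model({'img': cv2.imread('bag.jpg')})[OutputKeys.IMG_EMBEDDING]  # (512,)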
@@ -140,6 +140,12 @@ TASK_OUTPUTS = {
    # }
    Tasks.ocr_detection: [OutputKeys.POLYGONS],

+    # image embedding result for a single product image
+    # {
+    #   "img_embedding": np.array with shape [D]
+    # }
+    Tasks.product_retrieval_embedding: [OutputKeys.IMG_EMBEDDING],
+
    # video embedding result for single video
    # {
    #   "video_embedding": np.array with shape [D],
@@ -109,6 +109,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                                  'damo/cv_gan_face-image-generation'),
    Tasks.image_super_resolution: (Pipelines.image_super_resolution,
                                   'damo/cv_rrdb_image-super-resolution'),
+    Tasks.product_retrieval_embedding:
+    (Pipelines.product_retrieval_embedding,
+     'damo/cv_resnet50_product-bag-embedding-models'),
    Tasks.image_classification_imagenet:
    (Pipelines.general_image_classification,
     'damo/cv_vit-base_image-classification_ImageNet-labels'),
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
    from .face_detection_pipeline import FaceDetectionPipeline
    from .face_recognition_pipeline import FaceRecognitionPipeline
    from .face_image_generation_pipeline import FaceImageGenerationPipeline
+    from .image_classification_pipeline import ImageClassificationPipeline
    from .image_cartoon_pipeline import ImageCartoonPipeline
    from .image_classification_pipeline import GeneralImageClassificationPipeline
    from .image_denoise_pipeline import ImageDenoisePipeline
@@ -20,12 +21,12 @@ if TYPE_CHECKING:
    from .image_matting_pipeline import ImageMattingPipeline
    from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
    from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
+    from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline
    from .style_transfer_pipeline import StyleTransferPipeline
    from .live_category_pipeline import LiveCategoryPipeline
    from .ocr_detection_pipeline import OCRDetectionPipeline
    from .video_category_pipeline import VideoCategoryPipeline
    from .virtual_tryon_pipeline import VirtualTryonPipeline
-    from .image_classification_pipeline import ImageClassificationPipeline
else:
    _import_structure = {
        'action_recognition_pipeline': ['ActionRecognitionPipeline'],
@@ -47,6 +48,8 @@ else:
        'image_super_resolution_pipeline': ['ImageSuperResolutionPipeline'],
        'image_to_image_translation_pipeline':
        ['Image2ImageTranslationPipeline'],
+        'product_retrieval_embedding_pipeline':
+        ['ProductRetrievalEmbeddingPipeline'],
        'live_category_pipeline': ['LiveCategoryPipeline'],
        'ocr_detection_pipeline': ['OCRDetectionPipeline'],
        'style_transfer_pipeline': ['StyleTransferPipeline'],
@@ -0,0 +1,45 @@
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from modelscope.metainfo import Pipelines
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.product_retrieval_embedding,
    module_name=Pipelines.product_retrieval_embedding)
class ProductRetrievalEmbeddingPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create a pipeline for prediction.

        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Convert the input image to a cv2-style BGR array."""
        img = LoadImage.convert_to_ndarray(input)  # array in RGB order
        img = np.ascontiguousarray(img[:, :, ::-1])  # array in BGR order
        result = {'img': img}  # only for detection
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        return self.model(input)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs
@@ -37,6 +37,7 @@ class CVTasks(object):
    face_image_generation = 'face-image-generation'
    image_super_resolution = 'image-super-resolution'
    style_transfer = 'style-transfer'
+    product_retrieval_embedding = 'product-retrieval-embedding'
    live_category = 'live-category'
    video_category = 'video-category'
    image_classification_imagenet = 'image-classification-imagenet'
@@ -1,5 +1,8 @@
decord>=0.6.0
easydict
+# tensorflow 1.x compatibility requires the numpy version to be capped at 1.18
+numpy<=1.18
+onnxruntime>=1.10
tf_slim
timm
torchvision
@@ -4,7 +4,8 @@ easydict
einops
filelock>=3.3.0
gast>=0.2.2
-numpy
+# tensorflow 1.x compatibility requires the numpy version to be capped at 1.18
+numpy<=1.18
opencv-python
oss2
Pillow>=6.2.0
@@ -0,0 +1,39 @@
import unittest

import numpy as np

from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ProductRetrievalEmbeddingTest(unittest.TestCase):
    model_id = 'damo/cv_resnet50_product-bag-embedding-models'
    img_input = 'data/test/images/product_embed_bag.jpg'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_model_name(self):
        product_embed = pipeline(Tasks.product_retrieval_embedding,
                                 self.model_id)
        result = product_embed(self.img_input)[OutputKeys.IMG_EMBEDDING]
        print('abs sum value is: {}'.format(np.sum(np.abs(result))))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        model = Model.from_pretrained(self.model_id)
        product_embed = pipeline(
            task=Tasks.product_retrieval_embedding, model=model)
        result = product_embed(self.img_input)[OutputKeys.IMG_EMBEDDING]
        print('abs sum value is: {}'.format(np.sum(np.abs(result))))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        product_embed = pipeline(task=Tasks.product_retrieval_embedding)
        result = product_embed(self.img_input)[OutputKeys.IMG_EMBEDDING]
        print('abs sum value is: {}'.format(np.sum(np.abs(result))))


if __name__ == '__main__':
    unittest.main()