|
- # Copyright 2020 Huawei Technologies Co., Ltd
-
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
-
- # http://www.apache.org/licenses/LICENSE-2.0
-
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- import os
- import math
- import random
- import numpy as np
- import cv2
- from pycocotools.coco import COCO as ReadJson
-
- import mindspore.dataset as de
-
- from src.config import JointType, params
-
- cv2.setNumThreads(0)
-
class txtdataset():
    """Random-access dataset over COCO person images for OpenPose-style training.

    In ``train`` mode ``__getitem__`` returns an augmented image together with
    part-affinity fields (PAFs), joint heatmaps and an ignore mask.  In ``val``
    and ``eval`` mode it returns the raw image and its image id only.

    Args:
        train: annotation index exposing ``getCatIds``, ``getImgIds``,
            ``getAnnIds``, ``loadAnns`` and ``loadImgs`` (the file imports
            ``pycocotools.coco.COCO`` for this).
        imgpath: directory that contains the image files.
        maskpath: directory that contains ``{img_id:012d}.png`` ignore masks.
        insize: side length of the square network input.
        mode: ``'train'``, ``'val'`` or ``'eval'``.
        n_samples: when given in ``val``/``eval`` mode, number of image ids to
            draw at random (without replacement) from the person images.
    """

    def __init__(self, train, imgpath, maskpath, insize, mode='train', n_samples=None):
        self.train = train
        self.mode = mode
        self.imgpath = imgpath
        self.maskpath = maskpath
        self.insize = insize
        self.maxtime = 0
        self.catIds = train.getCatIds(catNms=['person'])
        self.imgIds = sorted(train.getImgIds(catIds=self.catIds))
        if self.mode == 'train':
            # Training needs at least one usable person per image.
            self.clean_imgIds()
        if self.mode in ['val', 'eval'] and n_samples is not None:
            self.imgIds = random.sample(self.imgIds, n_samples)
        print('{} images: {}'.format(mode, len(self)))

    def __len__(self):
        """Return the number of image ids currently held by the dataset."""
        return len(self.imgIds)

    def _load_annotations(self, img_id):
        """Load the annotations of one image and filter the usable ones.

        Returns:
            ``(all_annotations, valid_annotations)``.  ``all_annotations`` is
            an empty list when the image has no annotation ids.
            ``valid_annotations`` keeps the entries with at least
            ``params['min_keypoints']`` keypoints and an area above
            ``params['min_area']``; it is ``None`` when nothing qualifies.
        """
        anno_ids = self.train.getAnnIds(imgIds=[img_id], iscrowd=None)
        if not anno_ids:
            return [], None

        all_annotations = self.train.loadAnns(anno_ids)
        valid_annotations = [
            annotation for annotation in all_annotations
            if annotation['num_keypoints'] >= params['min_keypoints']
            and annotation['area'] > params['min_area']
        ]
        return all_annotations, (valid_annotations or None)

    def clean_imgIds(self):
        """Drop, in place, every image id that has no valid person annotation."""
        print("cleaning imgids")
        # One filtering pass instead of list.remove() per rejected id.
        self.imgIds[:] = [img_id for img_id in self.imgIds
                          if self._load_annotations(img_id)[1] is not None]

    def overlay_paf(self, img, paf):
        """Blend one 2-channel vector field (x, y) over ``img`` as a colour map.

        The vector direction drives the hue; the magnitude (clipped to 1)
        drives both saturation and value.
        """
        hue = ((np.arctan2(paf[1], paf[0]) / np.pi) / -2 + 0.5)
        saturation = np.sqrt(paf[0] ** 2 + paf[1] ** 2)
        saturation[saturation > 1.0] = 1.0
        value = saturation.copy()
        hsv_paf = np.vstack((hue[np.newaxis], saturation[np.newaxis], value[np.newaxis])).transpose(1, 2, 0)
        rgb_paf = cv2.cvtColor((hsv_paf * 255).astype(np.uint8), cv2.COLOR_HSV2BGR)
        img = cv2.addWeighted(img, 0.6, rgb_paf, 0.4, 0)
        return img

    @staticmethod
    def _average_limb_pafs(pafs, shape):
        """Average stacked limb PAFs into a single ``(2, h, w)`` field.

        ``pafs`` holds consecutive (x, y) channel pairs, one pair per limb.
        Each pixel is divided by the number of limbs that are non-zero there;
        pixels covered by no limb stay zero.
        """
        mix_paf = np.zeros((2,) + tuple(shape))
        paf_counts = np.zeros(mix_paf.shape)

        for paf in pafs.reshape((pafs.shape[0] // 2, 2) + pafs.shape[1:]):
            # A separate name keeps the running count from being overwritten.
            limb_flags = paf != 0
            paf_counts += np.broadcast_to(limb_flags[0] | limb_flags[1], paf.shape)
            mix_paf += paf

        covered = paf_counts > 0
        mix_paf[covered] /= paf_counts[covered]
        return mix_paf

    def overlay_pafs(self, img, pafs):
        """Blend the per-pixel average of all limb PAFs over ``img``."""
        mix_paf = self._average_limb_pafs(pafs, img.shape[:-1])
        return self.overlay_paf(img, mix_paf)

    def overlay_heatmap(self, img, heatmap):
        """Blend a single-channel heatmap over ``img`` using the JET colour map."""
        rgb_heatmap = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
        img = cv2.addWeighted(img, 0.6, rgb_heatmap, 0.4, 0)
        return img

    def overlay_ignore_mask(self, img, ignore_mask):
        """Zero out every pixel of ``img`` whose ``ignore_mask`` entry is non-zero."""
        keep = (ignore_mask == 0).astype(np.uint8)[:, :, None]
        return img * np.repeat(keep, 3, axis=2)

    # -------------------- augment code --------------------------------
    def get_pose_bboxes(self, poses):
        """Return one ``[x1, y1, x2, y2]`` box per pose around its visible joints.

        A joint counts as visible when its third component is greater than 0.
        """
        pose_bboxes = []
        for pose in poses:
            visible = pose[pose[:, 2] > 0]
            x1, y1 = visible[:, 0].min(), visible[:, 1].min()
            x2, y2 = visible[:, 0].max(), visible[:, 1].max()
            pose_bboxes.append([x1, y1, x2, y2])
        return np.array(pose_bboxes)

    def resize_data(self, img, ignore_mask, poses, shape):
        """Resize image and mask to ``shape`` (width, height) and rescale poses.

        NOTE: the joint coordinates of ``poses`` are rescaled in place.
        """
        img_h, img_w, _ = img.shape

        resized_img = cv2.resize(img, shape)
        ignore_mask = cv2.resize(ignore_mask.astype(np.uint8), shape).astype('bool')
        poses[:, :, :2] = (poses[:, :, :2] * np.array(shape) / np.array((img_w, img_h)))
        return resized_img, ignore_mask, poses

    def random_resize_img(self, img, ignore_mask, poses):
        """Resize by a random scale chosen from the pose bounding-box sizes.

        The scale range is derived from ``params['min_box_size']`` and
        ``params['max_box_size']`` and clamped with ``params['min_scale']``
        and ``params['max_scale']``.
        """
        h, w, _ = img.shape
        joint_bboxes = self.get_pose_bboxes(poses)
        # Diagonal length of every pose box.
        bbox_sizes = ((joint_bboxes[:, 2:] - joint_bboxes[:, :2] + 1) ** 2).sum(axis=1) ** 0.5

        min_scale = params['min_box_size'] / bbox_sizes.min()
        max_scale = params['max_box_size'] / bbox_sizes.max()

        min_scale = min(max(min_scale, params['min_scale']), 1)
        max_scale = min(max(max_scale, 1), params['max_scale'])

        scale = float((max_scale - min_scale) * random.random() + min_scale)
        shape = (round(w * scale), round(h * scale))

        return self.resize_data(img, ignore_mask, poses, shape)

    def random_rotate_img(self, img, mask, poses):
        """Rotate image, mask and poses by a random angle around the image centre.

        The angle is a normal sample scaled by ``params['max_rotate_degree'] / 3``.
        The output canvas is enlarged to hold the whole rotated image.
        """
        h, w, _ = img.shape
        degree = np.random.randn() / 3 * params['max_rotate_degree']
        rad = degree * math.pi / 180
        center = (w / 2, h / 2)
        R = cv2.getRotationMatrix2D(center, degree, 1)
        bbox = (w * abs(math.cos(rad)) + h * abs(math.sin(rad)), w * abs(math.sin(rad)) + h * abs(math.cos(rad)))
        # Shift so that the rotated image is centred in the enlarged canvas.
        R[0, 2] += bbox[0] / 2 - center[0]
        R[1, 2] += bbox[1] / 2 - center[1]
        out_size = (int(bbox[0] + 0.5), int(bbox[1] + 0.5))
        rotate_img = cv2.warpAffine(img, R, out_size, flags=cv2.INTER_CUBIC,
                                    borderMode=cv2.BORDER_CONSTANT, borderValue=[127.5, 127.5, 127.5])
        rotate_mask = cv2.warpAffine(mask.astype('uint8') * 255, R, out_size) > 0

        # Homogeneous coordinates (x, y, 1) so the affine matrix applies directly.
        tmp_poses = np.ones_like(poses)
        tmp_poses[:, :, :2] = poses[:, :, :2].copy()
        tmp_rotate_poses = np.dot(tmp_poses, R.T)
        rotate_poses = poses.copy()  # to keep visibility flag
        rotate_poses[:, :, :2] = tmp_rotate_poses
        return rotate_img, rotate_mask, rotate_poses

    def random_crop_img(self, img, ignore_mask, poses):
        """Crop an ``insize`` x ``insize`` window around a randomly chosen pose.

        The window centre is the centre of one pose box plus a perturbation of
        at most ``params['center_perterb_max']`` per axis.  Areas outside the
        source image are filled with grey in the image and ``False`` in the mask.

        NOTE: the joint coordinates of ``poses`` are shifted in place.
        """
        h, w, _ = img.shape
        insize = self.insize
        joint_bboxes = self.get_pose_bboxes(poses)
        bbox = random.choice(joint_bboxes)  # select a bbox randomly
        bbox_center = bbox[:2] + (bbox[2:] - bbox[:2]) / 2

        r_xy = np.random.rand(2)
        perturb = ((r_xy - 0.5) * 2 * params['center_perterb_max'])
        center = (bbox_center + perturb + 0.5).astype('i')

        crop_img = np.zeros((insize, insize, 3), 'uint8') + 127.5
        crop_mask = np.zeros((insize, insize), 'bool')

        # Top-left corner of the window in source coordinates (may be negative).
        offset = (center - (insize - 1) / 2 + 0.5).astype('i')
        # How far the bottom-right corner sticks out of the source image.
        offset_ = (center + (insize - 1) / 2 - (w - 1, h - 1) + 0.5).astype('i')

        x1, y1 = offset
        x2, y2 = (center + (insize - 1) / 2 + 0.5).astype('i')

        x1 = max(x1, 0)
        y1 = max(y1, 0)
        x2 = min(x2, w - 1)
        y2 = min(y2, h - 1)

        x_from = -offset[0] if offset[0] < 0 else 0
        y_from = -offset[1] if offset[1] < 0 else 0
        x_to = insize - offset_[0] - 1 if offset_[0] >= 0 else insize - 1
        y_to = insize - offset_[1] - 1 if offset_[1] >= 0 else insize - 1

        crop_img[y_from:y_to + 1, x_from:x_to + 1] = img[y1:y2 + 1, x1:x2 + 1].copy()
        crop_mask[y_from:y_to + 1, x_from:x_to + 1] = ignore_mask[y1:y2 + 1, x1:x2 + 1].copy()

        poses[:, :, :2] -= offset
        return crop_img.astype('uint8'), crop_mask, poses

    def distort_color(self, img):
        """Shift hue, saturation and value by random integer offsets.

        The offsets are drawn from [-10, 10], [-40, 40] and [-30, 30]
        respectively, and every channel is clamped to [0, 255].
        """
        img_max = np.broadcast_to(np.array(255, dtype=np.uint8), img.shape[:-1])
        img_min = np.zeros(img.shape[:-1], dtype=np.uint8)

        hsv_img = cv2.cvtColor(img.copy(), cv2.COLOR_BGR2HSV).astype(np.int32)
        hsv_img[:, :, 0] = np.maximum(np.minimum(hsv_img[:, :, 0] - 10 + np.random.randint(20 + 1), img_max), img_min)  # hue
        hsv_img[:, :, 1] = np.maximum(np.minimum(hsv_img[:, :, 1] - 40 + np.random.randint(80 + 1), img_max), img_min)  # saturation
        hsv_img[:, :, 2] = np.maximum(np.minimum(hsv_img[:, :, 2] - 30 + np.random.randint(60 + 1), img_max), img_min)  # value
        hsv_img = hsv_img.astype(np.uint8)

        distorted_img = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)
        return distorted_img

    def flip_img(self, img, mask, poses):
        """Mirror image, mask and poses horizontally and swap left/right joints.

        NOTE: ``poses`` is modified in place.
        """
        flipped_img = cv2.flip(img, 1)
        flipped_mask = cv2.flip(mask.astype(np.uint8), 1).astype('bool')
        poses[:, :, 0] = img.shape[1] - 1 - poses[:, :, 0]

        mirrored_pairs = (
            (JointType.LeftEye, JointType.RightEye),
            (JointType.LeftEar, JointType.RightEar),
            (JointType.LeftShoulder, JointType.RightShoulder),
            (JointType.LeftElbow, JointType.RightElbow),
            (JointType.LeftHand, JointType.RightHand),
            (JointType.LeftWaist, JointType.RightWaist),
            (JointType.LeftKnee, JointType.RightKnee),
            (JointType.LeftFoot, JointType.RightFoot),
        )
        for left_joint, right_joint in mirrored_pairs:
            tmp = poses[:, left_joint].copy()
            poses[:, left_joint] = poses[:, right_joint]
            poses[:, right_joint] = tmp
        return flipped_img, flipped_mask, poses

    def augment_data(self, img, ignore_mask, poses):
        """Apply random resize, rotation and crop, then optional colour jitter and flip.

        Colour distortion and the horizontal flip are each applied with
        probability 1/2.
        """
        aug_img = img.copy()
        aug_img, ignore_mask, poses = self.random_resize_img(aug_img, ignore_mask, poses)
        aug_img, ignore_mask, poses = self.random_rotate_img(aug_img, ignore_mask, poses)
        aug_img, ignore_mask, poses = self.random_crop_img(aug_img, ignore_mask, poses)
        if np.random.randint(2):
            aug_img = self.distort_color(aug_img)
        if np.random.randint(2):
            aug_img, ignore_mask, poses = self.flip_img(aug_img, ignore_mask, poses)

        return aug_img, ignore_mask, poses
    # ------------------------------- end -----------------------------------

    # ------------------------------ Heatmap ------------------------------------
    @staticmethod
    def _gaussian_heatmap(shape, joint, sigma, grid_offset):
        """Return an unnormalised Gaussian of shape ``(height, width)`` centred on ``joint``.

        ``grid_offset`` is added to every integer pixel coordinate before the
        distance to ``joint`` (x, y) is measured.
        """
        x, y = joint
        grid_x = np.arange(shape[1]) + grid_offset                    # (width,)
        grid_y = (np.arange(shape[0]) + grid_offset)[:, np.newaxis]   # (height, 1)
        grid_distance = (grid_x - x) ** 2 + (grid_y - y) ** 2         # broadcasts to (height, width)
        return np.exp(-0.5 * grid_distance / sigma ** 2)

    def _heatmaps_on_grid(self, shape, poses, sigma, downscale, joint_map_fn):
        """Build ``len(JointType) + 1`` heatmap channels on a ``shape`` grid.

        Each joint channel is the per-pixel maximum over all poses where that
        joint is visible.  The last channel is the background,
        ``1 - max over all joints``.  Joint positions and ``sigma`` are divided
        by ``downscale`` before ``joint_map_fn`` is called.
        """
        channels = [np.zeros((0,) + tuple(shape))]
        peak = np.zeros(shape)
        for joint_index in range(len(JointType)):
            heatmap = np.zeros(shape)
            for pose in poses:
                if pose[joint_index, 2] > 0:
                    jointmap = joint_map_fn(shape, pose[joint_index][:2] / downscale, sigma / downscale)
                    np.maximum(heatmap, jointmap, out=heatmap)
                    np.maximum(peak, jointmap, out=peak)
            channels.append(heatmap[np.newaxis])
        channels.append((1 - peak)[np.newaxis])  # background channel
        # Stack once at the end instead of re-copying on every joint.
        return np.vstack(channels).astype('f')

    # return shape: (height, width)
    def generate_gaussian_heatmap(self, shape, joint, sigma):
        """Gaussian heatmap for one joint on an integer pixel grid."""
        return self._gaussian_heatmap(shape, joint, sigma, 0.0)

    def generate_heatmaps(self, img, poses, heatmap_sigma):
        """Full-resolution joint heatmaps plus background for ``img``."""
        return self._heatmaps_on_grid(img.shape[:-1], poses, heatmap_sigma, 1,
                                      self.generate_gaussian_heatmap)

    def generate_gaussian_heatmap_fast(self, shape, joint, sigma):
        """Gaussian heatmap for one joint on a grid shifted by 0.4375.

        NOTE(review): 0.4375 presumably is the centre of an 8x8 cell expressed
        in 1/8-scale units (3.5 / 8) — confirm.
        """
        return self._gaussian_heatmap(shape, joint, sigma, 0.4375)

    def generate_heatmaps_fast(self, img, poses, heatmap_sigma):
        """Joint heatmaps plus background, computed directly at 1/8 resolution."""
        resize_shape = (img.shape[0] // 8, img.shape[1] // 8)
        return self._heatmaps_on_grid(resize_shape, poses, heatmap_sigma, 8,
                                      self.generate_gaussian_heatmap_fast)
    # ------------------------------ end ------------------------------------

    # ------------------------------ PAF ------------------------------------
    @staticmethod
    def _constant_paf(shape, joint_from, joint_to, paf_width, grid_offset):
        """Return the ``(2, height, width)`` unit-vector field of one limb.

        A pixel belongs to the limb when its projection on the limb axis lies
        between the two joints and its distance from the axis is at most
        ``paf_width``.  Limb pixels hold the unit vector from ``joint_from`` to
        ``joint_to``; the rest is zero.  ``shape`` is ``(height, width, channels)``
        and ``grid_offset`` is added to every integer pixel coordinate.
        """
        plane = tuple(shape[:-1])
        if np.array_equal(joint_from, joint_to):  # same joint
            return np.zeros((2,) + plane)

        joint_distance = np.linalg.norm(joint_to - joint_from)
        unit_vector = (joint_to - joint_from) / joint_distance
        rad = np.pi / 2
        # [[0, 1], [-1, 0]]
        rot_matrix = np.array([[np.cos(rad), np.sin(rad)], [-np.sin(rad), np.cos(rad)]])
        # [[u_y], [-u_x]]
        vertical_unit_vector = np.dot(rot_matrix, unit_vector)

        grid_x = np.arange(plane[1]) + grid_offset                    # (width,)
        grid_y = (np.arange(plane[0]) + grid_offset)[:, np.newaxis]   # (height, 1)
        rel_x = grid_x - joint_from[0]
        rel_y = grid_y - joint_from[1]

        horizontal_inner_product = unit_vector[0] * rel_x + unit_vector[1] * rel_y
        horizontal_paf_flag = (horizontal_inner_product >= 0) & (horizontal_inner_product <= joint_distance)
        vertical_inner_product = vertical_unit_vector[0] * rel_x + vertical_unit_vector[1] * rel_y
        vertical_paf_flag = np.abs(vertical_inner_product) <= paf_width
        paf_flag = horizontal_paf_flag & vertical_paf_flag

        return paf_flag[np.newaxis] * unit_vector[:, np.newaxis, np.newaxis]

    def _pafs_on_grid(self, shape, poses, paf_width, downscale, limb_paf_fn):
        """Build two PAF channels per entry of ``params['limbs_point']``.

        For every limb the fields of all poses where both joints are visible
        are summed, then divided by the number of contributing poses per
        pixel.  Joint positions and ``paf_width`` are divided by ``downscale``
        before ``limb_paf_fn`` is called.
        """
        plane = tuple(shape[:-1])
        stacked = [np.zeros((0,) + plane)]

        for limb in params['limbs_point']:
            paf = np.zeros((2,) + plane)
            paf_counts = np.zeros(paf.shape)  # number of poses covering each pixel

            for pose in poses:
                joint_from, joint_to = pose[limb]
                if joint_from[2] > 0 and joint_to[2] > 0:
                    limb_paf = limb_paf_fn(shape, joint_from[:2] / downscale, joint_to[:2] / downscale,
                                           paf_width / downscale)
                    limb_flags = limb_paf != 0
                    paf_counts += np.broadcast_to(limb_flags[0] | limb_flags[1], limb_paf.shape)
                    paf += limb_paf

            covered = paf_counts > 0
            paf[covered] /= paf_counts[covered]
            stacked.append(paf)
        # Stack once at the end instead of re-copying on every limb.
        return np.vstack(stacked).astype('f')

    # return shape: (2, height, width)
    def generate_constant_paf(self, shape, joint_from, joint_to, paf_width):
        """Unit-vector field of one limb on an integer pixel grid."""
        return self._constant_paf(shape, joint_from, joint_to, paf_width, 0.0)

    def generate_pafs(self, img, poses, paf_sigma):
        """Full-resolution PAFs for ``img``; ``paf_sigma`` is the limb half-width."""
        return self._pafs_on_grid(img.shape, poses, paf_sigma, 1, self.generate_constant_paf)

    def generate_constant_paf_fast(self, shape, joint_from, joint_to, paf_width):
        """Unit-vector field of one limb on a grid shifted by 0.4375."""
        return self._constant_paf(shape, joint_from, joint_to, paf_width, 0.4375)

    def generate_pafs_fast(self, img, poses, paf_sigma):
        """PAFs computed directly at 1/8 of the resolution of ``img``."""
        resize_shape = (img.shape[0] // 8, img.shape[1] // 8, 3)
        return self._pafs_on_grid(resize_shape, poses, paf_sigma, 8, self.generate_constant_paf_fast)
    # ------------------------------ end ------------------------------------

    def get_img_annotation(self, ind=None, img_id=None):
        """Load image, annotations and ignore mask for one sample.

        Args:
            ind: position in ``self.imgIds``; takes precedence over ``img_id``.
            img_id: image id, used when ``ind`` is ``None``.

        Returns:
            ``(img, img_id, annotations, ignore_mask)``.  In ``eval`` mode
            ``annotations`` is the unfiltered list (empty when the image has
            none).  Otherwise it is the filtered list, or ``None`` when no
            annotation passes the filter, and the mask is cast to float32.
        """
        if ind is not None:
            img_id = self.imgIds[ind]
        annotations_for_img, annotations = self._load_annotations(img_id)

        img_path = os.path.join(self.imgpath, self.train.loadImgs([img_id])[0]['file_name'])
        mask_path = os.path.join(self.maskpath, '{:012d}.png'.format(img_id))
        img = cv2.imread(img_path)
        ignore_mask = cv2.imread(mask_path, 0)
        if ignore_mask is None:
            # No mask file: nothing is ignored.
            ignore_mask = np.zeros(img.shape[:2], np.float32)
        else:
            ignore_mask[ignore_mask == 255] = 1

        if self.mode == 'eval':
            return img, img_id, annotations_for_img, ignore_mask

        return img, img_id, annotations, ignore_mask.astype('f')

    def parse_annotation(self, annotations):
        """Convert annotations to an int32 array of shape ``(n, len(JointType), 3)``.

        Keypoints are reordered through ``params['joint_indices']``.  The neck
        is set to the integer midpoint of the two shoulders, with visibility 2,
        when both shoulders are visible.
        """
        pose_list = [np.zeros((0, len(JointType), 3), dtype=np.int32)]

        for ann in annotations:
            ann_pose = np.array(ann['keypoints']).reshape(-1, 3)
            pose = np.zeros((1, len(JointType), 3), dtype=np.int32)
            joints = pose[0]

            # convert poses position
            for i, joint_index in enumerate(params['joint_indices']):
                joints[joint_index] = ann_pose[i]

            # compute neck position
            left_shoulder = joints[JointType.LeftShoulder]
            right_shoulder = joints[JointType.RightShoulder]
            if left_shoulder[2] > 0 and right_shoulder[2] > 0:
                joints[JointType.Neck][0] = int((left_shoulder[0] + right_shoulder[0]) / 2)
                joints[JointType.Neck][1] = int((left_shoulder[1] + right_shoulder[1]) / 2)
                joints[JointType.Neck][2] = 2

            pose_list.append(pose)
        return np.vstack(pose_list)

    def resize_output(self, input_np, map_h=46, map_w=46):
        """Resize a 2-D map, or every channel of a 3-D stack, to ``(map_h, map_w)``.

        The result is float32.
        """
        # cv2.resize takes its target size as (width, height).
        if len(input_np.shape) == 3:
            output = np.zeros((input_np.shape[0], map_h, map_w))
            for i in range(input_np.shape[0]):
                output[i] = cv2.resize(input_np[i], (map_w, map_h))
            return output.astype('f')

        input_np = input_np.astype('f')
        output = cv2.resize(input_np, (map_w, map_h))
        return output

    def generate_labels(self, img, poses, ignore_mask):
        """Augment one sample and build its training targets.

        Returns:
            ``(resized_img, pafs, heatmaps, resized_ignore_mask)``.
        """
        img, ignore_mask, poses = self.augment_data(img, ignore_mask, poses)
        resized_img, ignore_mask, resized_poses = self.resize_data(img, ignore_mask, poses,
                                                                   shape=(self.insize, self.insize))

        resized_heatmaps = self.generate_heatmaps_fast(resized_img, resized_poses, params['heatmap_sigma'])
        resized_pafs = self.generate_pafs_fast(resized_img, resized_poses, params['paf_sigma'])

        # Grow the ignored area with a 16x16 kernel before shrinking the mask.
        ignore_mask = cv2.morphologyEx(ignore_mask.astype('uint8'), cv2.MORPH_DILATE, np.ones((16, 16))).astype('bool')
        resized_ignore_mask = self.resize_output(ignore_mask)

        return resized_img, resized_pafs, resized_heatmaps, resized_ignore_mask

    def preprocess(self, img):
        """Scale an HWC image to float32 in [-0.5, 0.5] and return it as CHW."""
        x_data = img.astype('f')
        x_data /= 255
        x_data -= 0.5
        x_data = x_data.transpose(2, 0, 1)
        return x_data

    def __getitem__(self, i):
        """Return sample ``i``.

        ``val``/``eval`` mode returns ``(img, np.array([img_id]))``.  Otherwise
        returns ``(image, pafs, heatmaps, ignore_mask)``, where the returned
        mask is ``1 - ignore_mask``.
        """
        img, img_id, annotations, ignore_mask = self.get_img_annotation(ind=i)

        if self.mode in ['eval', 'val']:
            # don't need to make heatmaps/pafs
            return img, np.array([img_id])

        # if no annotations are available, fall back to a random other image
        while annotations is None:
            print("none annotations", img_id)
            img_id = self.imgIds[np.random.randint(len(self))]
            img, img_id, annotations, ignore_mask = self.get_img_annotation(img_id=img_id)

        poses = self.parse_annotation(annotations)

        resized_img, pafs, heatmaps, ignore_mask = self.generate_labels(img, poses, ignore_mask)
        resized_img = self.preprocess(resized_img)
        ignore_mask = 1. - ignore_mask

        return resized_img, pafs, heatmaps, ignore_mask
-
-
class DistributedSampler():
    """Index sampler that gives each of ``group_size`` ranks an equal share.

    Args:
        dataset: any object supporting ``len()``.
        rank: position of this process inside the group; selects its slice.
        group_size: number of processes sharing the dataset.
        shuffle: when true, a fresh permutation is drawn on every iteration.
        seed: starting seed; it is incremented before each shuffled iteration.
    """

    def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
        self.dataset = dataset
        self.rank = rank
        self.group_size = group_size
        self.dataset_len = len(self.dataset)
        # Samples per rank, rounded up so every rank gets the same count.
        self.num_samplers = int(math.ceil(self.dataset_len * 1.0 / self.group_size))
        self.total_size = self.num_samplers * self.group_size
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        """Yield the indices assigned to this rank."""
        if self.shuffle:
            self.seed = (self.seed + 1) & 0xffffffff
            # NOTE: this reseeds NumPy's global generator.
            np.random.seed(self.seed)
            indices = np.random.permutation(self.dataset_len).tolist()
        else:
            # dataset_len is already an int; calling len() on it raised TypeError.
            indices = list(range(self.dataset_len))
        # Pad with leading indices so the total divides evenly between ranks.
        indices += indices[:(self.total_size - len(indices))]
        indices = indices[self.rank::self.group_size]
        return iter(indices)

    def __len__(self):
        """Return the number of indices handed to each rank."""
        return self.num_samplers
-
def valdata(jsonpath, imgpath, rank, group_size, mode='val', maskpath=''):
    """Build the validation pipeline yielding ``img`` and ``img_id`` columns.

    The annotation file at ``jsonpath`` is loaded, wrapped in a ``txtdataset``
    and read through a ``DistributedSampler`` for this ``rank`` out of
    ``group_size``, using 8 parallel workers and a single repeat.
    """
    coco_annotations = ReadJson(jsonpath)
    source = txtdataset(coco_annotations, imgpath, maskpath, params['insize'], mode=mode)
    rank_sampler = DistributedSampler(source, rank, group_size)
    pipeline = de.GeneratorDataset(source, ['img', 'img_id'], num_parallel_workers=8, sampler=rank_sampler)
    return pipeline.repeat(1)
-
-
def create_dataset(jsonpath, imgpath, maskpath, batch_size, rank, group_size, mode='train', repeat_num=1, shuffle=True,
                   multiprocessing=True, num_worker=20):
    """Build the batched pipeline with image, pafs, heatmaps and ignore_mask columns.

    When ``group_size`` is not 1 the generator is sharded into ``group_size``
    shards and this process reads shard ``rank``.  Incomplete final batches are
    dropped, and the batched dataset is repeated ``repeat_num`` times.
    """
    coco_annotations = ReadJson(jsonpath)
    source = txtdataset(coco_annotations, imgpath, maskpath, params['insize'], mode=mode)

    generator_options = {
        'shuffle': shuffle,
        'num_parallel_workers': num_worker,
        'python_multiprocessing': multiprocessing,
    }
    if group_size != 1:
        # Sharding arguments are only passed when running distributed.
        generator_options['num_shards'] = group_size
        generator_options['shard_id'] = rank

    pipeline = de.GeneratorDataset(source, ["image", "pafs", "heatmaps", "ignore_mask"], **generator_options)
    pipeline = pipeline.batch(batch_size=batch_size, drop_remainder=True)
    return pipeline.repeat(repeat_num)
|