|
- #! /usr/bin/python
- # -*- coding: utf-8 -*-
-
- import os
-
- from tensorlayer import logging
- from tensorlayer.files.utils import (del_file, folder_exists, load_file_list, maybe_download_and_extract)
-
- __all__ = ['load_mpii_pose_dataset']
-
-
def _mpii_person_annotation(annopoint, head_coords, img_fn, train_flag):
    """Build the annotation dict for one annotated person of an MPII image.

    Parameters
    -----------
    annopoint : numpy.ndarray
        Non-empty ``annopoints`` struct of one person, as loaded by ``scipy.io.loadmat``.
    head_coords : tuple
        The ``(x1, y1, x2, y2)`` head rectangle cells of the same person.
    img_fn : str
        File name of the image the person belongs to.
    train_flag : int
        Value of the ``img_train`` flag of the image.

    Returns
    ----------
    dict
        Keys ``filename``, ``train``, ``head_rect``, ``is_visible`` and ``joint_pos``.
    """
    head_rect = [float(coord[0, 0]) for coord in head_coords]

    # joint coordinates
    points = annopoint['point'][0, 0]
    joint_ids = [str(j_i[0, 0]) for j_i in points['id'][0]]
    xs = [x[0, 0] for x in points['x'][0]]
    ys = [y[0, 0] for y in points['y'][0]]
    joint_pos = {int(j_id): [float(x), float(y)] for j_id, x, y in zip(joint_ids, xs, ys)}

    # visibility list: keyed by the joint id as a string (unlike `joint_pos`, which uses int keys);
    # an empty visibility cell is recorded as 0.
    if 'is_visible' in str(points.dtype):
        vis_cells = [v[0] if v.size > 0 else [0] for v in points['is_visible'][0]]
        vis = {j_id: int(v[0]) if len(v) > 0 else 0 for j_id, v in zip(joint_ids, vis_cells)}
    else:
        vis = None

    return {
        'filename': img_fn,
        'train': train_flag,
        'head_rect': head_rect,
        'is_visible': vis,
        'joint_pos': joint_pos,
    }


def _keep_existing_images(img_list, ann_list, existing_images, missing_msg, img_dir):
    """Return copies of ``img_list`` / ``ann_list`` without the entries whose image is not in ``existing_images``.

    New lists are built instead of deleting in place, because deleting from a list
    while iterating over it skips the element that follows each removed one.
    ``missing_msg`` is a format string receiving the image name and ``img_dir``.
    """
    kept_imgs = []
    kept_anns = []
    for im, ann in zip(img_list, ann_list):
        if im in existing_images:
            kept_imgs.append(im)
            kept_anns.append(ann)
        else:
            print(missing_msg.format(im, img_dir))
    return kept_imgs, kept_anns


def load_mpii_pose_dataset(path='data', is_16_pos_only=False):
    """Load MPII Human Pose Dataset.

    The annotation archive and the image archive are downloaded and extracted into
    ``<path>/mpii_human_pose`` when their extracted folders do not exist yet. The
    annotation ``.mat`` file is then parsed, and images that are listed in the
    annotations but missing on disk are dropped together with their annotations.

    Parameters
    -----------
    path : str
        The path that the data is downloaded to.
    is_16_pos_only : boolean
        If True, only return the people that contain 16 pose keypoints. (Usually used for single person pose estimation)

    Returns
    ----------
    img_train_list : list of str
        The image paths of training data.
    ann_train_list : list of list of dict
        The annotations of training data, one list per image with one dict per annotated person
        (keys: ``filename``, ``train``, ``head_rect``, ``is_visible``, ``joint_pos``).
        The list of an image may be empty.
    img_test_list : list of str
        The image paths of testing data.
    ann_test_list : list of list of dict
        The annotations of testing data, same layout as ``ann_train_list``.

    Examples
    --------
    >>> import pprint
    >>> import tensorlayer as tl
    >>> img_train_list, ann_train_list, img_test_list, ann_test_list = tl.files.load_mpii_pose_dataset()
    >>> image = tl.vis.read_image(img_train_list[0])
    >>> tl.vis.draw_mpii_pose_to_image(image, ann_train_list[0], 'image.png')
    >>> pprint.pprint(ann_train_list[0])

    References
    -----------
    - `MPII Human Pose Dataset. CVPR 14 <http://human-pose.mpi-inf.mpg.de>`__
    - `MPII Human Pose Models. CVPR 16 <http://pose.mpi-inf.mpg.de>`__
    - `MPII Human Shape, Poselet Conditioned Pictorial Structures and etc <http://pose.mpi-inf.mpg.de/#related>`__
    - `MPII Keypoints and ID <http://human-pose.mpi-inf.mpg.de/#download>`__
    """
    path = os.path.join(path, 'mpii_human_pose')
    logging.info("Load or Download MPII Human Pose > {}".format(path))

    url = "http://datasets.d2.mpi-inf.mpg.de/andriluka14cvpr/"

    # annotation
    tar_filename = "mpii_human_pose_v1_u12_2.zip"
    extracted_filename = "mpii_human_pose_v1_u12_2"
    if folder_exists(os.path.join(path, extracted_filename)) is False:
        logging.info("[MPII] (annotation) {} is nonexistent in {}".format(extracted_filename, path))
        maybe_download_and_extract(tar_filename, path, url, extract=True)
        del_file(os.path.join(path, tar_filename))

    # images
    tar_filename = "mpii_human_pose_v1.tar.gz"
    extracted_filename2 = "images"
    if folder_exists(os.path.join(path, extracted_filename2)) is False:
        # report the images folder here, not the annotation folder
        logging.info("[MPII] (images) {} is nonexistent in {}".format(extracted_filename2, path))
        maybe_download_and_extract(tar_filename, path, url, extract=True)
        del_file(os.path.join(path, tar_filename))

    # parse annotation, format see http://human-pose.mpi-inf.mpg.de/#download
    import scipy.io as sio
    logging.info("reading annotations from mat file ...")
    mat = sio.loadmat(os.path.join(path, extracted_filename, "mpii_human_pose_v1_u12_1.mat"))
    release = mat['RELEASE']

    ann_train_list = []
    ann_test_list = []
    img_train_list = []
    img_test_list = []

    for anno, train_flag in zip(release['annolist'][0, 0][0], release['img_train'][0, 0][0]):  # all images
        img_fn = anno['image']['name'][0, 0][0]
        train_flag = int(train_flag)

        # Every image gets an entry, even when no person is kept for it, so that
        # the image list and the annotation list stay index-aligned.
        people = []
        if train_flag:
            img_train_list.append(img_fn)
            ann_train_list.append(people)
        else:
            img_test_list.append(img_fn)
            ann_test_list.append(people)

        annorect = anno['annorect']
        if 'annopoints' not in str(annorect.dtype):
            continue  # no keypoint annotation for this image

        people_cells = zip(
            annorect['annopoints'][0], annorect['x1'][0], annorect['y1'][0], annorect['x2'][0], annorect['y2'][0]
        )
        for annopoint, head_x1, head_y1, head_x2, head_y2 in people_cells:
            if not annopoint.size:
                continue  # person rectangle without keypoints

            person = _mpii_person_annotation(annopoint, (head_x1, head_y1, head_x2, head_y2), img_fn, train_flag)

            # only use people with 16 key points / or use all
            if not is_16_pos_only or len(person['joint_pos']) == 16:
                people.append(person)

    ## read images dir
    logging.info("reading images list ...")
    img_dir = os.path.join(path, extracted_filename2)
    # a set gives constant-time membership tests for the thousands of images
    existing_images = set(load_file_list(path=img_dir, regx='\\.jpg', printable=False))
    img_train_list, ann_train_list = _keep_existing_images(
        img_train_list, ann_train_list, existing_images,
        'missing training image {} in {} (remove from img(ann)_train_list)', img_dir
    )
    # the testing split is filtered against its own lists, never the training ones
    img_test_list, ann_test_list = _keep_existing_images(
        img_test_list, ann_test_list, existing_images,
        'missing testing image {} in {} (remove from img(ann)_test_list)', img_dir
    )

    ## check annotation and images
    n_train_images = len(img_train_list)
    n_test_images = len(img_test_list)
    n_images = n_train_images + n_test_images
    logging.info("n_images: {} n_train_images: {} n_test_images: {}".format(n_images, n_train_images, n_test_images))
    n_train_ann = len(ann_train_list)
    n_test_ann = len(ann_test_list)
    n_ann = n_train_ann + n_test_ann
    logging.info("n_ann: {} n_train_ann: {} n_test_ann: {}".format(n_ann, n_train_ann, n_test_ann))
    n_train_people = sum(len(people) for people in ann_train_list)
    n_test_people = sum(len(people) for people in ann_test_list)
    n_people = n_train_people + n_test_people
    logging.info("n_people: {} n_train_people: {} n_test_people: {}".format(n_people, n_train_people, n_test_people))

    # add path to all image file name
    img_train_list = [os.path.join(img_dir, name) for name in img_train_list]
    img_test_list = [os.path.join(img_dir, name) for name in img_test_list]
    return img_train_list, ann_train_list, img_test_list, ann_test_list
|