| @@ -142,10 +142,10 @@ Parameters for both training and evaluation can be set in config.py | |||
| 'batch_size': 10 # training batch size | |||
| 'lr_gamma': 0.1 # lr scale when reach lr_steps | |||
| 'lr_steps': '100000,200000,250000' # the steps when lr * lr_gamma | |||
| 'loss scale': 16386 # the loss scale of mixed precision | |||
| 'loss scale': 16384 # the loss scale of mixed precision | |||
| 'max_epoch_train': 60 # total training epochs | |||
| 'insize': 368 # image size used as input to the model | |||
| 'keep_checkpoint_max': 5 # only keep the last keep_checkpoint_max checkpoint | |||
| 'keep_checkpoint_max': 1 # only keep the last keep_checkpoint_max checkpoint | |||
| 'log_interval': 100 # the interval of print a log | |||
| 'ckpt_interval': 5000 # the interval of saving a output model | |||
| ``` | |||
| @@ -195,7 +195,7 @@ For more configuration details, please refer the script `config.py`. | |||
| ```python | |||
| # grep "AP" eval.log | |||
| {'AP': 0.40030956300341397, 'Ap .5': 0.6658941566481336, 'AP .75': 0.396047897339743, 'AP (M)': 0.3075356543635785, 'AP (L)': 0.533772768618845, 'AR': 0.4519836272040302, 'AR .5': 0.693639798488665, 'AR .75': 0.4570214105793451, 'AR (M)': 0.32155148866429945, 'AR (L)': 0.6330360460795242} | |||
| {'AP': 0.39830956300341397, 'Ap .5': 0.6658941566481336, 'AP .75': 0.396047897339743, 'AP (M)': 0.3075356543635785, 'AP (L)': 0.533772768618845, 'AR': 0.4519836272040302, 'AR .5': 0.693639798488665, 'AR .75': 0.4570214105793451, 'AR (M)': 0.32155148866429945, 'AR (L)': 0.6330360460795242} | |||
| ``` | |||
| @@ -209,14 +209,14 @@ For more configuration details, please refer the script `config.py`. | |||
| | -------------------------- | ----------------------------------------------------------- | |||
| | Model Version | openpose | |||
| | Resource | Ascend 910 ;CPU 2.60GHz,192cores;Memory,755G | |||
| | uploaded Date | 10/20/2020 (month/day/year) | |||
| | uploaded Date | 12/14/2020 (month/day/year) | |||
| | MindSpore Version | 1.0.1-alpha | |||
| | Training Parameters | epoch = 60, steps = 30k, batch_size = 10, lr = 0.0001 | |||
| | Optimizer | Adam | |||
| | Training Parameters | epoch=60(1pcs)/80(8pcs), steps=30k(1pcs)/5k(8pcs), batch_size=10, init_lr=0.0001 | |||
| | Optimizer | Adam(1pcs)/Momentum(8pcs) | |||
| | Loss Function | MSE | |||
| | outputs | pose | |||
| | Speed | 1pc: 29imgs/s | |||
| | Total time | 1pc: 30h | |||
| | Speed | 1pcs: 35fps, 8pcs: 230fps | |||
| | Total time | 1pcs: 22.5h, 8pcs: 5.1h | |||
| | Checkpoint for Fine tuning | 602.33M (.ckpt file) | |||
| @@ -17,10 +17,9 @@ import os | |||
| import argparse | |||
| import warnings | |||
| import sys | |||
| import cv2 | |||
| from tqdm import tqdm | |||
| import numpy as np | |||
| from tqdm import tqdm | |||
| import cv2 | |||
| from scipy.ndimage.filters import gaussian_filter | |||
| from pycocotools.coco import COCO as LoadAnn | |||
| from pycocotools.cocoeval import COCOeval as MapEval | |||
| @@ -30,9 +29,10 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.communication.management import init, get_rank, get_group_size | |||
| from mindspore.common import dtype as mstype | |||
| from src.dataset import valdata | |||
| from src.openposenet import OpenPoseNet | |||
| from src.config import params, JointType | |||
| from src.openposenet import OpenPoseNet | |||
| from src.dataset import valdata | |||
| warnings.filterwarnings("ignore") | |||
| devid = int(os.getenv('DEVICE_ID')) | |||
| @@ -40,6 +40,18 @@ context.set_context(mode=context.GRAPH_MODE, | |||
| device_target="Ascend", save_graphs=False, device_id=devid) | |||
| show_gt = 0 | |||
| parser = argparse.ArgumentParser('mindspore openpose_net test') | |||
| parser.add_argument('--model_path', type=str, default='./0-33_170000.ckpt', help='path of testing model') | |||
| parser.add_argument('--imgpath_val', type=str, default='./dataset/coco/val2017', help='path of testing imgs') | |||
| parser.add_argument('--ann', type=str, default='./dataset/coco/annotations/person_keypoints_val2017.json', | |||
| help='path of annotations') | |||
| parser.add_argument('--output_path', type=str, default='./output_img', help='path of testing imgs') | |||
| # distributed related | |||
| parser.add_argument('--is_distributed', type=int, default=0, help='if multi device') | |||
| parser.add_argument('--rank', type=int, default=0, help='local rank of distributed') | |||
| parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') | |||
| args, _ = parser.parse_known_args() | |||
| def evaluate_mAP(res_file, ann_file, ann_type='keypoints', silence=True): | |||
| class NullWriter(): | |||
| def write(self, arg): | |||
| @@ -68,23 +80,6 @@ def evaluate_mAP(res_file, ann_file, ann_type='keypoints', silence=True): | |||
| return info_str | |||
| def parse_args(): | |||
| """Parse arguments.""" | |||
| parser = argparse.ArgumentParser('mindspore openpose_net test') | |||
| parser.add_argument('--model_path', type=str, default='./scripts/train_parallel0/checkpoints/ckpt_0/0-60_663.ckpt', | |||
| help='path of testing model') | |||
| parser.add_argument('--imgpath_val', type=str, default='/data0/zhy/dataset/coco/val2017', | |||
| help='path of testing imgs') | |||
| parser.add_argument('--ann', type=str, default='/data0/zhy/dataset/coco/annotations/person_keypoints_val2017.json', | |||
| help='path of annotations') | |||
| parser.add_argument('--output_path', type=str, default='./output_img', help='path of testing imgs') | |||
| # distributed related | |||
| parser.add_argument('--is_distributed', type=int, default=0, help='if multi device') | |||
| parser.add_argument('--rank', type=int, default=0, help='local rank of distributed') | |||
| parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') | |||
| args, _ = parser.parse_known_args() | |||
| return args | |||
| def load_model(test_net, model_path): | |||
| assert os.path.exists(model_path) | |||
| @@ -178,7 +173,7 @@ def compute_peaks_from_heatmaps(heatmaps): | |||
| return all_peaks | |||
| def compute_candidate_connections(paf, cand_a, cand_b, img_len, cfg): | |||
| def compute_candidate_connections(paf, cand_a, cand_b, img_len, params_): | |||
| candidate_connections = [] | |||
| for joint_a in cand_a: | |||
| for joint_b in cand_b: | |||
| @@ -186,33 +181,33 @@ def compute_candidate_connections(paf, cand_a, cand_b, img_len, cfg): | |||
| norm = np.linalg.norm(vector) | |||
| if norm == 0: | |||
| continue | |||
| ys = np.linspace(joint_a[1], joint_b[1], num=cfg['n_integ_points']) | |||
| xs = np.linspace(joint_a[0], joint_b[0], num=cfg['n_integ_points']) | |||
| ys = np.linspace(joint_a[1], joint_b[1], num=params_['n_integ_points']) | |||
| xs = np.linspace(joint_a[0], joint_b[0], num=params_['n_integ_points']) | |||
| integ_points = np.stack([ys, xs]).T.round().astype('i') | |||
| paf_in_edge = np.hstack([paf[0][np.hsplit(integ_points, 2)], paf[1][np.hsplit(integ_points, 2)]]) | |||
| unit_vector = vector / norm | |||
| inner_products = np.dot(paf_in_edge, unit_vector) | |||
| integ_value = inner_products.sum() / len(inner_products) | |||
| integ_value_with_dist_prior = integ_value + min(cfg['limb_length_ratio'] * img_len / norm - | |||
| cfg['length_penalty_value'], 0) | |||
| n_valid_points = sum(inner_products > cfg['inner_product_thresh']) | |||
| if n_valid_points > cfg['n_integ_points_thresh'] and integ_value_with_dist_prior > 0: | |||
| integ_value_with_dist_prior = integ_value + min(params_['limb_length_ratio'] * img_len / norm - | |||
| params_['length_penalty_value'], 0) | |||
| n_valid_points = sum(inner_products > params_['inner_product_thresh']) | |||
| if n_valid_points > params_['n_integ_points_thresh'] and integ_value_with_dist_prior > 0: | |||
| candidate_connections.append([int(joint_a[3]), int(joint_b[3]), integ_value_with_dist_prior]) | |||
| candidate_connections = sorted(candidate_connections, key=lambda x: x[2], reverse=True) | |||
| return candidate_connections | |||
| def compute_connections(pafs, all_peaks, img_len, cfg): | |||
| def compute_connections(pafs, all_peaks, img_len, params_): | |||
| all_connections = [] | |||
| for i in range(len(cfg['limbs_point'])): | |||
| for i in range(len(params_['limbs_point'])): | |||
| paf_index = [i * 2, i * 2 + 1] | |||
| paf = pafs[paf_index] # shape: (2, 320, 320) | |||
| limb_point = cfg['limbs_point'][i] # example: [<JointType.Neck: 1>, <JointType.RightWaist: 8>] | |||
| limb_point = params_['limbs_point'][i] # example: [<JointType.Neck: 1>, <JointType.RightWaist: 8>] | |||
| cand_a = all_peaks[all_peaks[:, 0] == limb_point[0]][:, 1:] | |||
| cand_b = all_peaks[all_peaks[:, 0] == limb_point[1]][:, 1:] | |||
| if cand_a.shape[0] > 0 and cand_b.shape[0] > 0: | |||
| candidate_connections = compute_candidate_connections(paf, cand_a, cand_b, img_len, cfg) | |||
| candidate_connections = compute_candidate_connections(paf, cand_a, cand_b, img_len, params_) | |||
| connections = np.zeros((0, 3)) | |||
| @@ -226,11 +221,11 @@ def compute_connections(pafs, all_peaks, img_len, cfg): | |||
| all_connections.append(np.zeros((0, 3))) | |||
| return all_connections | |||
| def grouping_key_points(all_connections, candidate_peaks, cfg): | |||
| def grouping_key_points(all_connections, candidate_peaks, params_): | |||
| subsets = -1 * np.ones((0, 20)) | |||
| for l, connections in enumerate(all_connections): | |||
| joint_a, joint_b = cfg['limbs_point'][l] | |||
| joint_a, joint_b = params_['limbs_point'][l] | |||
| for ind_a, ind_b, score in connections[:, :3]: | |||
| ind_a, ind_b = int(ind_a), int(ind_b) | |||
| joint_found_cnt = 0 | |||
| @@ -249,6 +244,7 @@ def grouping_key_points(all_connections, candidate_peaks, cfg): | |||
| found_subset[-1] += 1 # increment joint count | |||
| found_subset[-2] += candidate_peaks[ind_b, 3] + score | |||
| elif joint_found_cnt == 2: | |||
| found_subset_1 = subsets[joint_found_subset_index[0]] | |||
| @@ -289,10 +285,8 @@ def grouping_key_points(all_connections, candidate_peaks, cfg): | |||
| pass | |||
| # delete low score subsets | |||
| keep = np.logical_and(subsets[:, -1] >= cfg['n_subset_limbs_thresh'], | |||
| subsets[:, -2] / subsets[:, -1] >= cfg['subset_score_thresh']) | |||
| # cfg['n_subset_limbs_thresh'] = 3 | |||
| # cfg['subset_score_thresh'] = 0.2 | |||
| keep = np.logical_and(subsets[:, -1] >= params_['n_subset_limbs_thresh'], | |||
| subsets[:, -2] / subsets[:, -1] >= params_['subset_score_thresh']) | |||
| subsets = subsets[keep] | |||
| return subsets | |||
| @@ -319,7 +313,7 @@ def detect(img, network): | |||
| # map_w, map_h = compute_optimal_size(orig_img, params['heatmap_size']) # 320 | |||
| map_w, map_h = compute_optimal_size(orig_img, params['inference_img_size']) | |||
| print("image size is: ", input_w, input_h) | |||
| # print("image size is: ", input_w, input_h) | |||
| resized_image = cv2.resize(orig_img, (input_w, input_h)) | |||
| x_data = preprocess(resized_image) | |||
| @@ -394,7 +388,7 @@ def draw_person_pose(orig_img, poses): | |||
| return canvas | |||
| def depreprocess(img): | |||
| # x_data = img.astype('f') | |||
| #x_data = img.astype('f') | |||
| x_data = img[0] | |||
| x_data += 0.5 | |||
| x_data *= 255 | |||
| @@ -402,15 +396,14 @@ def depreprocess(img): | |||
| x_data = x_data.transpose(1, 2, 0) | |||
| return x_data | |||
| def _eval(): | |||
| args = parse_args() | |||
| def val(): | |||
| if args.is_distributed: | |||
| init() | |||
| args.rank = get_rank() | |||
| args.group_size = get_group_size() | |||
| if not os.path.exists(args.output_path): | |||
| os.mkdir(args.output_path) | |||
| network = OpenPoseNet() | |||
| network = OpenPoseNet(vgg_with_bn=params['vgg_with_bn']) | |||
| network.set_train(False) | |||
| load_model(network, args.model_path) | |||
| @@ -455,4 +448,4 @@ def _eval(): | |||
| print('result: ', res) | |||
| if __name__ == "__main__": | |||
| _eval() | |||
| val() | |||
| @@ -15,9 +15,8 @@ | |||
| # ============================================================================ | |||
| export DEVICE_ID=0 | |||
| export RANK_ID=0 | |||
| python eval.py \ | |||
| --model_path ./scripts/train_parallel0/checkpoints/ckpt_0/0-60_663.ckpt \ | |||
| --imgpath_val /data0/zhy/dataset/coco/val2017 \ | |||
| --ann /data0/zhy/dataset/coco/annotations/person_keypoints_val2017.json \ | |||
| --model_path ./scripts/train_parallel0/checkpoints/ckpt_0/0-80_663.ckpt \ | |||
| --imgpath_val ./dataset/val2017 \ | |||
| --ann ./dataset/annotations/person_keypoints_val2017.json \ | |||
| > eval.log 2>&1 & | |||
| @@ -14,5 +14,6 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| export DEVICE_ID=0 | |||
| cd .. | |||
| python train.py --train_dir train2017 --train_ann person_keypoints_train2017.json > scripts/train.log 2>&1 & | |||
| @@ -53,21 +53,41 @@ class JointType(IntEnum): | |||
| params = { | |||
| # paths | |||
| 'data_dir': '/data0/zhy/dataset/coco', | |||
| 'vgg_path': '/data0/zhy/dataset/coco/vgg19-0-97_5004.ckpt', | |||
| 'data_dir': './dataset', | |||
| 'save_model_path': './checkpoints/', | |||
| 'load_pretrain': False, | |||
| 'pretrained_model_path': "", | |||
| # training params | |||
| 'batch_size': 10, | |||
| 'lr': 1e-4, | |||
| 'lr_gamma': 0.1, | |||
| 'lr_steps': '100000,200000,250000', | |||
| 'lr_steps_NP': '250000', | |||
| # train type | |||
| 'train_type': 'fix_loss_scale', # chose in ['clip_grad', 'fix_loss_scale'] | |||
| 'train_type_NP': 'clip_grad', | |||
| # vgg bn | |||
| 'vgg_with_bn': False, | |||
| 'vgg_path': './vgg_model/vgg19-0-97_5004.ckpt', | |||
| # if clip_grad | |||
| 'GRADIENT_CLIP_TYPE': 1, | |||
| 'GRADIENT_CLIP_VALUE': 10.0, | |||
| 'loss_scale': 16386, | |||
| # optimizer and lr | |||
| 'optimizer': "Adam", # chose in ['Momentum', 'Adam'] | |||
| 'optimizer_NP': "Momentum", | |||
| 'group_params': True, | |||
| 'group_params_NP': False, | |||
| 'lr': 1e-4, | |||
| 'lr_type': 'default', # chose in ["default", "cosine"] | |||
| 'lr_gamma': 0.1, # if default | |||
| 'lr_steps': '100000,200000,250000', # if default | |||
| 'lr_steps_NP': '250000,300000', # if default | |||
| 'warmup_epoch': 5, # if cosine | |||
| 'max_epoch_train': 60, | |||
| 'max_epoch_train_NP': 80, | |||
| 'loss_scale': 16384, | |||
| # default param | |||
| 'batch_size': 10, | |||
| 'min_keypoints': 5, | |||
| 'min_area': 32 * 32, | |||
| 'insize': 368, | |||
| @@ -75,9 +95,9 @@ params = { | |||
| 'paf_sigma': 8, | |||
| 'heatmap_sigma': 7, | |||
| 'eva_num': 100, | |||
| 'keep_checkpoint_max': 5, | |||
| 'keep_checkpoint_max': 1, | |||
| 'log_interval': 100, | |||
| 'ckpt_interval': 663, # 5000, | |||
| 'ckpt_interval': 5304, | |||
| 'min_box_size': 64, | |||
| 'max_box_size': 512, | |||
| @@ -15,10 +15,10 @@ | |||
| import os | |||
| import math | |||
| import random | |||
| import cv2 | |||
| import numpy as np | |||
| import cv2 | |||
| from pycocotools.coco import COCO as ReadJson | |||
| import mindspore.dataset as de | |||
| from src.config import JointType, params | |||
| @@ -41,6 +41,7 @@ class txtdataset(): | |||
| self.imgIds = random.sample(self.imgIds, n_samples) | |||
| print('{} images: {}'.format(mode, len(self))) | |||
| def __len__(self): | |||
| return len(self.imgIds) | |||
| @@ -217,9 +218,9 @@ class txtdataset(): | |||
| flipped_mask = cv2.flip(mask.astype(np.uint8), 1).astype('bool') | |||
| poses[:, :, 0] = img.shape[1] - 1 - poses[:, :, 0] | |||
| def swap_joints(poses, joint_type_1, joint_type_2): | |||
| tmp = poses[:, joint_type_1].copy() | |||
| poses[:, joint_type_1] = poses[:, joint_type_2] | |||
| def swap_joints(poses, joint_type_, joint_type_2): | |||
| tmp = poses[:, joint_type_].copy() | |||
| poses[:, joint_type_] = poses[:, joint_type_2] | |||
| poses[:, joint_type_2] = tmp | |||
| swap_joints(poses, JointType.LeftEye, JointType.RightEye) | |||
| @@ -243,8 +244,10 @@ class txtdataset(): | |||
| aug_img, ignore_mask, poses = self.flip_img(aug_img, ignore_mask, poses) | |||
| return aug_img, ignore_mask, poses | |||
| # ------------------------------------------------------------------ | |||
| # ------------------------------- end ----------------------------------- | |||
| # ------------------------------ Heatmap ------------------------------------ | |||
| # return shape: (height, width) | |||
| def generate_gaussian_heatmap(self, shape, joint, sigma): | |||
| x, y = joint | |||
| @@ -269,6 +272,38 @@ class txtdataset(): | |||
| heatmaps = np.vstack((heatmaps, bg_heatmap[None])) | |||
| return heatmaps.astype('f') | |||
| def generate_gaussian_heatmap_fast(self, shape, joint, sigma): | |||
| x, y = joint | |||
| grid_x = np.tile(np.arange(shape[1]), (shape[0], 1)) | |||
| grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose() | |||
| grid_x = grid_x + 0.4375 | |||
| grid_y = grid_y + 0.4375 | |||
| grid_distance = (grid_x - x) ** 2 + (grid_y - y) ** 2 | |||
| gaussian_heatmap = np.exp(-0.5 * grid_distance / sigma**2) | |||
| return gaussian_heatmap | |||
| def generate_heatmaps_fast(self, img, poses, heatmap_sigma): | |||
| resize_shape = (img.shape[0] // 8, img.shape[1] // 8) | |||
| heatmaps = np.zeros((0,) + resize_shape) | |||
| sum_heatmap = np.zeros(resize_shape) | |||
| for joint_index in range(len(JointType)): | |||
| heatmap = np.zeros(resize_shape) | |||
| for pose in poses: | |||
| if pose[joint_index, 2] > 0: | |||
| jointmap = self.generate_gaussian_heatmap_fast(resize_shape, pose[joint_index][:2]/8, | |||
| heatmap_sigma/8) | |||
| index_1 = jointmap > heatmap | |||
| heatmap[index_1] = jointmap[index_1] | |||
| index_2 = jointmap > sum_heatmap | |||
| sum_heatmap[index_2] = jointmap[index_2] | |||
| heatmaps = np.vstack((heatmaps, heatmap.reshape((1,) + heatmap.shape))) | |||
| bg_heatmap = 1 - sum_heatmap # background channel | |||
| heatmaps = np.vstack((heatmaps, bg_heatmap[None])) | |||
| return heatmaps.astype('f') | |||
| # ------------------------------ end ------------------------------------ | |||
| # ------------------------------ PAF ------------------------------------ | |||
| # return shape: (2, height, width) | |||
| def generate_constant_paf(self, shape, joint_from, joint_to, paf_width): | |||
| if np.array_equal(joint_from, joint_to): # same joint | |||
| @@ -285,7 +320,7 @@ class txtdataset(): | |||
| grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose() | |||
| horizontal_inner_product = unit_vector[0] * (grid_x - joint_from[0]) + unit_vector[1] * (grid_y - joint_from[1]) | |||
| horizontal_paf_flag = (horizontal_inner_product >= 0) & (horizontal_inner_product <= joint_distance) | |||
| vertical_inner_product = vertical_unit_vector[0] * (grid_x - joint_from[0]) + vertical_unit_vector[1] * \ | |||
| vertical_inner_product = vertical_unit_vector[0] * (grid_x - joint_from[0]) + vertical_unit_vector[1] *\ | |||
| (grid_y - joint_from[1]) | |||
| vertical_paf_flag = np.abs(vertical_inner_product) <= paf_width # paf_width : 8 | |||
| paf_flag = horizontal_paf_flag & vertical_paf_flag | |||
| @@ -314,6 +349,55 @@ class txtdataset(): | |||
| pafs = np.vstack((pafs, paf)) | |||
| return pafs.astype('f') | |||
| def generate_constant_paf_fast(self, shape, joint_from, joint_to, paf_width): | |||
| if np.array_equal(joint_from, joint_to): # same joint | |||
| return np.zeros((2,) + shape[:-1]) | |||
| joint_distance = np.linalg.norm(joint_to - joint_from) | |||
| unit_vector = (joint_to - joint_from) / joint_distance | |||
| rad = np.pi / 2 | |||
| # [[0, 1], [-1, 0]] | |||
| rot_matrix = np.array([[np.cos(rad), np.sin(rad)], [-np.sin(rad), np.cos(rad)]]) | |||
| # [[u_y], [-u_x]] | |||
| vertical_unit_vector = np.dot(rot_matrix, unit_vector) | |||
| grid_x = np.tile(np.arange(shape[1]), (shape[0], 1)) | |||
| grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose() | |||
| grid_x = grid_x + 0.4375 | |||
| grid_y = grid_y + 0.4375 | |||
| horizontal_inner_product = unit_vector[0] * (grid_x - joint_from[0]) + unit_vector[1] * (grid_y - joint_from[1]) | |||
| horizontal_paf_flag = (horizontal_inner_product >= 0) & (horizontal_inner_product <= joint_distance) | |||
| vertical_inner_product = vertical_unit_vector[0] * (grid_x - joint_from[0]) + vertical_unit_vector[1] *\ | |||
| (grid_y - joint_from[1]) | |||
| vertical_paf_flag = np.abs(vertical_inner_product) <= paf_width # paf_width : 8/8 = 1 | |||
| paf_flag = horizontal_paf_flag & vertical_paf_flag | |||
| constant_paf = np.stack((paf_flag, paf_flag)) *\ | |||
| np.broadcast_to(unit_vector, shape[:-1] + (2,)).transpose(2, 0, 1) | |||
| return constant_paf | |||
| def generate_pafs_fast(self, img, poses, paf_sigma): | |||
| resize_shape = (img.shape[0]//8, img.shape[1]//8, 3) | |||
| pafs = np.zeros((0,) + resize_shape[:-1]) | |||
| for limb in params['limbs_point']: | |||
| paf = np.zeros((2,) + resize_shape[:-1]) | |||
| paf_flags = np.zeros(paf.shape) # for constant paf | |||
| for pose in poses: | |||
| joint_from, joint_to = pose[limb] | |||
| if joint_from[2] > 0 and joint_to[2] > 0: | |||
| limb_paf = self.generate_constant_paf_fast(resize_shape, joint_from[:2]/8, joint_to[:2]/8, paf_sigma/8) # [2, 368, 368] | |||
| limb_paf_flags = limb_paf != 0 | |||
| paf_flags += np.broadcast_to(limb_paf_flags[0] | limb_paf_flags[1], limb_paf.shape) | |||
| paf += limb_paf | |||
| index_1 = paf_flags > 0 | |||
| paf[index_1] /= paf_flags[index_1] | |||
| pafs = np.vstack((pafs, paf)) | |||
| return pafs.astype('f') | |||
| # ------------------------------ end ------------------------------------ | |||
| def get_img_annotation(self, ind=None, img_id=None): | |||
| annotations = None | |||
| @@ -389,14 +473,18 @@ class txtdataset(): | |||
| resized_img, ignore_mask, resized_poses = self.resize_data(img, ignore_mask, poses, | |||
| shape=(self.insize, self.insize)) | |||
| heatmaps = self.generate_heatmaps(resized_img, resized_poses, params['heatmap_sigma']) | |||
| pafs = self.generate_pafs(resized_img, resized_poses, params['paf_sigma']) # params['paf_sigma']: 8 | |||
| # heatmaps = self.generate_heatmaps(resized_img, resized_poses, params['heatmap_sigma']) | |||
| # resized_heatmaps = self.resize_output(heatmaps) | |||
| resized_heatmaps = self.generate_heatmaps_fast(resized_img, resized_poses, params['heatmap_sigma']) | |||
| # pafs = self.generate_pafs(resized_img, resized_poses, params['paf_sigma']) | |||
| # resized_pafs = self.resize_output(pafs) | |||
| resized_pafs = self.generate_pafs_fast(resized_img, resized_poses, params['paf_sigma']) | |||
| ignore_mask = cv2.morphologyEx(ignore_mask.astype('uint8'), cv2.MORPH_DILATE, np.ones((16, 16))).astype('bool') | |||
| resized_pafs = self.resize_output(pafs) | |||
| resized_heatmaps = self.resize_output(heatmaps) | |||
| resized_ignore_mask = self.resize_output(ignore_mask) | |||
| return resized_img, resized_pafs, resized_heatmaps, resized_ignore_mask | |||
| def preprocess(self, img): | |||
| @@ -459,7 +547,6 @@ class DistributedSampler(): | |||
| def __len__(self): | |||
| return self.num_samplers | |||
| def valdata(jsonpath, imgpath, rank, group_size, mode='val', maskpath=''): | |||
| #cv2.setNumThreads(0) | |||
| val = ReadJson(jsonpath) | |||
| @@ -470,23 +557,6 @@ def valdata(jsonpath, imgpath, rank, group_size, mode='val', maskpath=''): | |||
| return ds | |||
| def openpose(jsonpath, imgpath, maskpath, per_batch_size, max_epoch, rank, group_size, mode='train'): | |||
| train = ReadJson(jsonpath) | |||
| num_parallel = 48 | |||
| if group_size > 1: | |||
| num_parallel = 20 | |||
| dataset = txtdataset(train, imgpath, maskpath, params['insize'], mode=mode) | |||
| sampler = DistributedSampler(dataset, rank, group_size) | |||
| de_dataset = de.GeneratorDataset(dataset, ["image", "pafs", "heatmaps", "ignore_mask"], | |||
| num_parallel_workers=num_parallel, sampler=sampler, shuffle=True) | |||
| de_dataset = de_dataset.project(columns=["image", "pafs", "heatmaps", "ignore_mask"]) | |||
| de_dataset = de_dataset.batch(batch_size=per_batch_size, drop_remainder=True, num_parallel_workers=num_parallel) | |||
| steap_pre_epoch = de_dataset.get_dataset_size() | |||
| de_dataset = de_dataset.repeat(max_epoch) | |||
| return de_dataset, steap_pre_epoch | |||
| def create_dataset(jsonpath, imgpath, maskpath, batch_size, rank, group_size, mode='train', repeat_num=1, shuffle=True, | |||
| multiprocessing=True, num_worker=20): | |||
| @@ -15,7 +15,6 @@ | |||
| import os | |||
| import argparse | |||
| import cv2 | |||
| import numpy as np | |||
| from tqdm import tqdm | |||
| from pycocotools.coco import COCO as ReadJson | |||
| @@ -23,44 +22,44 @@ from pycocotools.coco import COCO as ReadJson | |||
| from config import params | |||
| class DataLoader(): | |||
| def __init__(self, coco, dir_name, data_mode='train'): | |||
| self.train = coco | |||
| def __init__(self, train_, dir_name, mode_='train'): | |||
| self.train = train_ | |||
| self.dir_name = dir_name | |||
| assert data_mode in ['train', 'val'], 'Data loading mode is invalid.' | |||
| self.mode = data_mode | |||
| self.catIds = coco.getCatIds() # catNms=['person'] | |||
| self.imgIds = sorted(coco.getImgIds(catIds=self.catIds)) | |||
| assert mode_ in ['train', 'val'], 'Data loading mode is invalid.' | |||
| self.mode = mode_ | |||
| self.catIds = train_.getCatIds() # catNms=['person'] | |||
| self.imgIds = sorted(train_.getImgIds(catIds=self.catIds)) | |||
| def __len__(self): | |||
| return len(self.imgIds) | |||
| def gen_masks(self, image, anns): | |||
| _mask_all = np.zeros(image.shape[:2], 'bool') | |||
| _mask_miss = np.zeros(image.shape[:2], 'bool') | |||
| for ann in anns: | |||
| def gen_masks(self, image_, annotations_): | |||
| mask_all_1 = np.zeros(image_.shape[:2], 'bool') | |||
| mask_miss_1 = np.zeros(image_.shape[:2], 'bool') | |||
| for ann in annotations_: | |||
| mask = self.train.annToMask(ann).astype('bool') | |||
| if ann['iscrowd'] == 1: | |||
| intxn = _mask_all & mask | |||
| _mask_miss = np.bitwise_or(_mask_miss.astype(int), np.subtract(mask, intxn, dtype=np.int32)) | |||
| _mask_all = np.bitwise_or(_mask_all.astype(int), mask.astype(int)) | |||
| intxn = mask_all_1 & mask | |||
| mask_miss_1 = np.bitwise_or(mask_miss_1.astype(int), np.subtract(mask, intxn, dtype=np.int32)) | |||
| mask_all_1 = np.bitwise_or(mask_all_1.astype(int), mask.astype(int)) | |||
| elif ann['num_keypoints'] < params['min_keypoints'] or ann['area'] <= params['min_area']: | |||
| _mask_all = np.bitwise_or(_mask_all.astype(int), mask.astype(int)) | |||
| _mask_miss = np.bitwise_or(_mask_miss.astype(int), mask.astype(int)) | |||
| mask_all_1 = np.bitwise_or(mask_all_1.astype(int), mask.astype(int)) | |||
| mask_miss_1 = np.bitwise_or(mask_miss_1.astype(int), mask.astype(int)) | |||
| else: | |||
| _mask_all = np.bitwise_or(_mask_all.astype(int), mask.astype(int)) | |||
| return _mask_all, _mask_miss | |||
| mask_all_1 = np.bitwise_or(mask_all_1.astype(int), mask.astype(int)) | |||
| return mask_all_1, mask_miss_1 | |||
| def dwaw_gen_masks(self, image, mask, color=(0, 0, 1)): | |||
| def dwaw_gen_masks(self, image_, mask, color=(0, 0, 1)): | |||
| bimsk = np.repeat(mask[:, :, np.newaxis], 3, axis=2) | |||
| mskd = image * bimsk.astype(np.int32) | |||
| mskd = image_ * bimsk.astype(np.int32) | |||
| clmsk = np.ones(bimsk.shape) * bimsk | |||
| for ind in range(3): | |||
| clmsk[:, :, ind] = clmsk[:, :, ind] * color[ind] * 255 | |||
| image = image + 0.7 * clmsk - 0.7 * mskd | |||
| return image.astype(np.uint8) | |||
| for index_1 in range(3): | |||
| clmsk[:, :, index_1] = clmsk[:, :, index_1] * color[index_1] * 255 | |||
| image_ = image_ + 0.7 * clmsk - 0.7 * mskd | |||
| return image_.astype(np.uint8) | |||
| def draw_masks_and_keypoints(self, image, anns): | |||
| for ann in anns: | |||
| def draw_masks_and_keypoints(self, image_, annotations_): | |||
| for ann in annotations_: | |||
| # masks | |||
| mask = self.train.annToMask(ann).astype(np.uint8) | |||
| if ann['iscrowd'] == 1: | |||
| @@ -70,30 +69,30 @@ class DataLoader(): | |||
| else: | |||
| color = (1, 0, 0) | |||
| bimsk = np.repeat(mask[:, :, np.newaxis], 3, axis=2) | |||
| mskd = image * bimsk.astype(np.int32) | |||
| mskd = image_ * bimsk.astype(np.int32) | |||
| clmsk = np.ones(bimsk.shape) * bimsk | |||
| for ind in range(3): | |||
| clmsk[:, :, ind] = clmsk[:, :, ind] * color[ind] * 255 | |||
| image = image + 0.7 * clmsk - 0.7 * mskd | |||
| for index_1 in range(3): | |||
| clmsk[:, :, index_1] = clmsk[:, :, index_1] * color[index_1] * 255 | |||
| image_ = image_ + 0.7 * clmsk - 0.7 * mskd | |||
| # keypoints | |||
| for x, y, v in np.array(ann['keypoints']).reshape(-1, 3): | |||
| if v == 1: | |||
| cv2.circle(image, (x, y), 3, (255, 255, 0), -1) | |||
| cv2.circle(image_, (x, y), 3, (255, 255, 0), -1) | |||
| elif v == 2: | |||
| cv2.circle(image, (x, y), 3, (255, 0, 255), -1) | |||
| return image.astype(np.uint8) | |||
| cv2.circle(image_, (x, y), 3, (255, 0, 255), -1) | |||
| return image_.astype(np.uint8) | |||
| def get_img_annotation(self, ind=None, image_id=None): | |||
| def get_img_annotation(self, ind=None, img_id_=None): | |||
| if ind is not None: | |||
| image_id = self.imgIds[ind] | |||
| img_id_ = self.imgIds[ind] | |||
| anno_ids = self.train.getAnnIds(imgIds=[image_id]) | |||
| _annotations = self.train.loadAnns(anno_ids) | |||
| anno_ids = self.train.getAnnIds(imgIds=[img_id_]) | |||
| annotations_ = self.train.loadAnns(anno_ids) | |||
| img_file = os.path.join(params['data_dir'], self.dir_name, self.train.loadImgs([image_id])[0]['file_name']) | |||
| _image = cv2.imread(img_file) | |||
| return _image, _annotations, image_id | |||
| img_file = os.path.join(params['data_dir'], self.dir_name, self.train.loadImgs([img_id_])[0]['file_name']) | |||
| image_ = cv2.imread(img_file) | |||
| return image_, annotations_, img_id_ | |||
| if __name__ == '__main__': | |||
| @@ -107,7 +106,7 @@ if __name__ == '__main__': | |||
| path_list = [args.train_ann, args.val_ann, args.train_dir, args.val_dir] | |||
| for index, mode in enumerate(['train', 'val']): | |||
| train = ReadJson(path_list[index]) | |||
| data_loader = DataLoader(train, path_list[index+2], mode=mode) | |||
| data_loader = DataLoader(train, path_list[index+2], mode_=mode) | |||
| save_dir = os.path.join(params['data_dir'], 'ignore_mask_{}'.format(mode)) | |||
| if not os.path.exists(save_dir): | |||
| @@ -12,37 +12,53 @@ | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import time | |||
| import mindspore.nn as nn | |||
| from mindspore.ops import operations as P | |||
| from mindspore.nn.loss.loss import _Loss | |||
| from mindspore.train.callback import Callback | |||
| from mindspore.ops import functional as F | |||
| from mindspore.ops import composite as C | |||
| from mindspore.communication.management import get_group_size | |||
| from mindspore.context import ParallelMode | |||
| from mindspore import context | |||
| from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean | |||
| from mindspore.nn.wrap.grad_reducer import DistributedGradReducer | |||
| from src.config import params | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| time_stamp_init = False | |||
| time_stamp_first = 0 | |||
| grad_scale = C.MultitypeFuncGraph("grad_scale") | |||
| _grad_overflow = C.MultitypeFuncGraph("_grad_overflow") | |||
| reciprocal = P.Reciprocal() | |||
| GRADIENT_CLIP_TYPE = 1 | |||
| GRADIENT_CLIP_VALUE = 1.0 | |||
| @grad_scale.register("Tensor", "Tensor") | |||
| def tensor_grad_scale(scale, grad): | |||
| return grad * F.cast(reciprocal(scale), F.dtype(grad)) | |||
| @grad_scale.register("Tensor", "RowTensor") | |||
| def tensor_grad_scale_row_tensor(scale, grad): | |||
| return RowTensor(grad.indices, | |||
| grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)), | |||
| grad.dense_shape) | |||
| GRADIENT_CLIP_TYPE = params['GRADIENT_CLIP_TYPE'] | |||
| GRADIENT_CLIP_VALUE = params['GRADIENT_CLIP_VALUE'] | |||
| clip_grad = C.MultitypeFuncGraph("clip_grad") | |||
| @clip_grad.register("Number", "Number", "Tensor") | |||
| def _clip_grad(clip_type, clip_value, grad): | |||
| """ | |||
| Clip gradients. | |||
| Inputs: | |||
| clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. | |||
| clip_value (float): Specifies how much to clip. | |||
| grad (tuple[Tensor]): Gradients. | |||
| Outputs: | |||
| tuple[Tensor]: clipped gradients. | |||
| """ | |||
| if clip_type not in (0, 1): | |||
| return grad | |||
| dt = F.dtype(grad) | |||
| if clip_type == 0: | |||
| new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), | |||
| F.cast(F.tuple_to_array((clip_value,)), dt)) | |||
| else: | |||
| new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) | |||
| return new_grad | |||
| class openpose_loss(_Loss): | |||
| def __init__(self): | |||
| super(openpose_loss, self).__init__() | |||
| @@ -99,109 +115,49 @@ class openpose_loss(_Loss): | |||
| return total_loss, heatmaps_loss, pafs_loss | |||
| class Depend_network(nn.Cell): | |||
| def __init__(self, network): | |||
| super(Depend_network, self).__init__() | |||
| class BuildTrainNetwork(nn.Cell): | |||
| def __init__(self, network, criterion): | |||
| super(BuildTrainNetwork, self).__init__() | |||
| self.network = network | |||
| self.criterion = criterion | |||
| def construct(self, *args): | |||
| loss, _, _ = self.network(*args) # loss, heatmaps_loss, pafs_loss | |||
| def construct(self, input_data, gt_paf, gt_heatmap, mask): | |||
| logit_pafs, logit_heatmap = self.network(input_data) | |||
| loss, _, _ = self.criterion(logit_pafs, logit_heatmap, gt_paf, gt_heatmap, mask) | |||
| return loss | |||
| #loss = self.criterion(logit_pafs, logit_heatmap, gt_paf, gt_heatmap, mask) | |||
| # return loss, heatmaps_loss, pafs_loss | |||
| class TrainingWrapper(nn.Cell): | |||
| def __init__(self, network, optimizer, sens=1): | |||
| super(TrainingWrapper, self).__init__(auto_prefix=False) | |||
| class TrainOneStepWithClipGradientCell(nn.Cell): | |||
| '''TrainOneStepWithClipGradientCell''' | |||
| def __init__(self, network, optimizer, sens=1.0): | |||
| super(TrainOneStepWithClipGradientCell, self).__init__(auto_prefix=False) | |||
| self.network = network | |||
| self.depend_network = Depend_network(network) | |||
| # self.weights = ms.ParameterTuple(network.trainable_params()) | |||
| self.network.set_grad() | |||
| self.network.add_flags(defer_inline=True) | |||
| self.weights = optimizer.parameters | |||
| self.optimizer = optimizer | |||
| self.grad = C.GradOperation(get_by_list=True, sens_param=True) | |||
| self.hyper_map = C.HyperMap() | |||
| self.sens = sens | |||
| self.reducer_flag = False | |||
| self.grad_reducer = None | |||
| self.print = P.Print() | |||
| self.parallel_mode = context.get_auto_parallel_context("parallel_mode") | |||
| if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: | |||
| parallel_mode = _get_parallel_mode() | |||
| if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): | |||
| self.reducer_flag = True | |||
| if self.reducer_flag: | |||
| mean = context.get_auto_parallel_context("gradients_mean") | |||
| #if mean.get_device_num_is_set(): | |||
| # if mean: | |||
| #degree = context.get_auto_parallel_context("device_num") | |||
| # else: | |||
| degree = get_group_size() | |||
| self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree) | |||
| def construct(self, *args): | |||
| mean = _get_gradients_mean() | |||
| degree = _get_device_num() | |||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||
| def construct(self, *inputs): | |||
| weights = self.weights | |||
| loss, heatmaps_loss, pafs_loss = self.network(*args) | |||
| loss = self.network(*inputs) | |||
| sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) | |||
| #grads = self.grad(self.network, weights)(*args, sens) | |||
| grads = self.grad(self.depend_network, weights)(*args, sens) | |||
| grads = self.grad(self.network, weights)(*inputs, sens) | |||
| grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) | |||
| if self.reducer_flag: | |||
| # apply grad reducer on grads | |||
| grads = self.grad_reducer(grads) | |||
| #return F.depend(loss, self.optimizer(grads)) | |||
| # for grad in grads: | |||
| # self.print(grad) | |||
| loss = F.depend(loss, self.optimizer(grads)) | |||
| return loss, heatmaps_loss, pafs_loss | |||
| class BuildTrainNetwork(nn.Cell): | |||
| def __init__(self, network, criterion): | |||
| super(BuildTrainNetwork, self).__init__() | |||
| self.network = network | |||
| self.criterion = criterion | |||
| def construct(self, input_data, gt_paf, gt_heatmap, mask): | |||
| logit_pafs, logit_heatmap = self.network(input_data) | |||
| loss, _, _ = self.criterion(logit_pafs, logit_heatmap, gt_paf, gt_heatmap, mask) | |||
| return loss | |||
| class LossCallBack(Callback): | |||
| """ | |||
| Monitor the loss in training. | |||
| If the loss is NAN or INF terminating training. | |||
| Note: | |||
| If per_print_times is 0 do not print loss. | |||
| Args: | |||
| per_print_times (int): Print loss every times. Default: 1. | |||
| """ | |||
| def __init__(self, per_print_times=1): | |||
| super(LossCallBack, self).__init__() | |||
| if not isinstance(per_print_times, int) or per_print_times < 0: | |||
| raise ValueError("print_step must be int and >= 0.") | |||
| self._per_print_times = per_print_times | |||
| self.count = 0 | |||
| self.loss_sum = 0 | |||
| global time_stamp_init, time_stamp_first | |||
| if not time_stamp_init: | |||
| time_stamp_first = time.time() | |||
| time_stamp_init = True | |||
| def step_end(self, run_context): | |||
| cb_params = run_context.original_args() | |||
| loss = cb_params.net_outputs.asnumpy() | |||
| self.count += 1 | |||
| self.loss_sum += float(loss) | |||
| cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 | |||
| if self.count >= 1: | |||
| global time_stamp_first | |||
| time_stamp_current = time.time() | |||
| loss = self.loss_sum/self.count | |||
| loss_file = open("./loss.log", "a+") | |||
| loss_file.write("%lu epoch: %s step: %s ,loss: %.5f" % | |||
| (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch, | |||
| loss)) | |||
| loss_file.write("\n") | |||
| loss_file.close() | |||
| self.count = 0 | |||
| self.loss_sum = 0 | |||
| return F.depend(loss, self.optimizer(grads)) | |||
| @@ -17,19 +17,18 @@ from mindspore.nn import Conv2d, ReLU | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.ops import operations as P | |||
| from mindspore import context | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| #selfCat = P.Concat(axis=1) | |||
| time_stamp_init = False | |||
| time_stamp_first = 0 | |||
| loadvgg = 1 | |||
| class OpenPoseNet(nn.Cell): | |||
| insize = 368 | |||
| def __init__(self, vggpath=''): | |||
| def __init__(self, vggpath='', vgg_with_bn=False): | |||
| super(OpenPoseNet, self).__init__() | |||
| self.base = Base_model() | |||
| self.base = Base_model(vgg_with_bn=vgg_with_bn) | |||
| self.stage_1 = Stage_1() | |||
| self.stage_2 = Stage_x() | |||
| self.stage_3 = Stage_x() | |||
| @@ -39,23 +38,15 @@ class OpenPoseNet(nn.Cell): | |||
| self.shape = P.Shape() | |||
| self.cat = P.Concat(axis=1) | |||
| self.print = P.Print() | |||
| # for m in self.modules(): | |||
| # if isinstance(m, Conv2d): | |||
| # init.constant_(m.bias, 0) | |||
| if loadvgg and vggpath: | |||
| param_dict = load_checkpoint(vggpath) | |||
| param_dict_new = {} | |||
| trans_name = 'base.vgg_base.' | |||
| for key, values in param_dict.items(): | |||
| #print('key:',key,self.shape(values)) | |||
| if key.startswith('moments.'): | |||
| continue | |||
| elif key.startswith('network.'): | |||
| param_dict_new[trans_name+key[17:]] = values | |||
| # else: | |||
| # param_dict_new[key] = values | |||
| #print(param_dict_new) | |||
| load_param_into_net(self.base.vgg_base, param_dict_new) | |||
| def construct(self, x): | |||
| @@ -205,20 +196,17 @@ class VGG_Base_MS(nn.Cell): | |||
| return x | |||
| class Base_model(nn.Cell): | |||
| def __init__(self): | |||
| def __init__(self, vgg_with_bn=False): | |||
| super(Base_model, self).__init__() | |||
| #cfgs_zh = {'16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512]} | |||
| cfgs_zh = {'19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512]} | |||
| #cfgs_zh = {'16': [64, 64,128, 128, 256, 256, 256, 512, 512, 512]} | |||
| self.vgg_base = Vgg(cfgs_zh['19'], batch_norm=False) | |||
| #self.vgg_base = VGG_Base() | |||
| #self.vgg_base = VGG_Base_MS() | |||
| self.vgg_base = Vgg(cfgs_zh['19'], batch_norm=vgg_with_bn) | |||
| self.conv4_3_CPM = Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, pad_mode='same', | |||
| has_bias=True) | |||
| self.conv4_4_CPM = Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, pad_mode='same', | |||
| has_bias=True) | |||
| self.relu = ReLU() | |||
| def construct(self, x): | |||
| x = self.vgg_base(x) | |||
| x = self.relu(self.conv4_3_CPM(x)) | |||
| @@ -1,27 +1,11 @@ | |||
| # Copyright 2020 Huawei Technologies Co., Ltd | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import argparse | |||
| import os | |||
| import math | |||
| import time | |||
| import numpy as np | |||
| from mindspore.train.serialization import load_checkpoint, load_param_into_net | |||
| from mindspore.train.callback import LossMonitor | |||
| from mindspore.train.callback import LossMonitor, Callback | |||
| from mindspore.common.tensor import Tensor | |||
| from src.config import params | |||
| from mindspore.common import dtype as mstype | |||
| class MyLossMonitor(LossMonitor): | |||
| def __init__(self, per_print_times=1): | |||
| @@ -32,6 +16,7 @@ class MyLossMonitor(LossMonitor): | |||
| def step_end(self, run_context): | |||
| cb_params = run_context.original_args() | |||
| loss = cb_params.net_outputs | |||
| if isinstance(loss, (tuple, list)): | |||
| @@ -47,63 +32,76 @@ class MyLossMonitor(LossMonitor): | |||
| raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format( | |||
| cb_params.cur_epoch_num, cur_step_in_epoch)) | |||
| if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: | |||
| # print("epoch: %s step: %s, loss is %s, step time: %.3f s." % (cb_params.cur_epoch_num, cur_step_in_epoch, | |||
| # loss, | |||
| # (time.time() - self._start_time)), flush=True) | |||
| self._loss_list.append(loss) | |||
| if cb_params.cur_step_num % 100 == 0: | |||
| print("epoch: %s, steps: [%s] mean loss is: %s"%(cb_params.cur_epoch_num, cur_step_in_epoch, | |||
| np.array(self._loss_list).mean()), flush=True) | |||
| print("epoch: %s, steps: [%s], mean loss is: %s"%(cb_params.cur_epoch_num, cur_step_in_epoch, | |||
| np.array(self._loss_list).mean()), flush=True) | |||
| self._loss_list = [] | |||
| self._start_time = time.time() | |||
| class MyScaleSensCallback(Callback): | |||
| '''MyLossScaleCallback''' | |||
| def __init__(self, loss_scale_list, epoch_list): | |||
| super(MyScaleSensCallback, self).__init__() | |||
| self.loss_scale_list = loss_scale_list | |||
| self.epoch_list = epoch_list | |||
| self.scaling_sens = loss_scale_list[0] | |||
| def parse_args(): | |||
| """Parse train arguments.""" | |||
| parser = argparse.ArgumentParser('mindspore openpose training') | |||
| # dataset related | |||
| parser.add_argument('--train_dir', type=str, default='train2017', help='train data dir') | |||
| parser.add_argument('--train_ann', type=str, default='person_keypoints_train2017.json', | |||
| help='train annotations json') | |||
| parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') | |||
| args, _ = parser.parse_known_args() | |||
| args.jsonpath_train = os.path.join(params['data_dir'], 'annotations/' + args.train_ann) | |||
| args.imgpath_train = os.path.join(params['data_dir'], args.train_dir) | |||
| args.maskpath_train = os.path.join(params['data_dir'], 'ignore_mask_train') | |||
| return args | |||
| def get_lr(lr, lr_gamma, steps_per_epoch, max_epoch_train, lr_steps, group_size): | |||
| lr_stage = np.array([lr] * steps_per_epoch * max_epoch_train).astype('f') | |||
| for step in lr_steps: | |||
| step //= group_size | |||
| lr_stage[step:] *= lr_gamma | |||
| def epoch_end(self, run_context): | |||
| cb_params = run_context.original_args() | |||
| epoch = cb_params.cur_epoch_num | |||
| for i, _ in enumerate(self.epoch_list): | |||
| if epoch >= self.epoch_list[i]: | |||
| self.scaling_sens = self.loss_scale_list[i+1] | |||
| else: | |||
| break | |||
| scaling_sens_tensor = Tensor(self.scaling_sens, dtype=mstype.float32) | |||
| cb_params.train_network.set_sense_scale(scaling_sens_tensor) | |||
| print("Epoch: set train network scale sens to {}".format(self.scaling_sens)) | |||
| def _linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr): | |||
| lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) | |||
| learning_rate = float(init_lr) + lr_inc * current_step | |||
| return learning_rate | |||
| def _a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps): | |||
| base = float(current_step - warmup_steps) / float(decay_steps) | |||
| learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr | |||
| return learning_rate | |||
| def _dynamic_lr(base_lr, total_steps, warmup_steps, warmup_ratio=1 / 3): | |||
| lr = [] | |||
| for i in range(total_steps): | |||
| if i < warmup_steps: | |||
| lr.append(_linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * warmup_ratio)) | |||
| else: | |||
| lr.append(_a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps)) | |||
| return lr | |||
| def get_lr(lr, lr_gamma, steps_per_epoch, max_epoch_train, lr_steps, group_size, lr_type='default', warmup_epoch=5): | |||
| if lr_type == 'default': | |||
| lr_stage = np.array([lr] * steps_per_epoch * max_epoch_train).astype('f') | |||
| for step in lr_steps: | |||
| step //= group_size | |||
| lr_stage[step:] *= lr_gamma | |||
| elif lr_type == 'cosine': | |||
| lr_stage = _dynamic_lr(lr, steps_per_epoch * max_epoch_train, warmup_epoch * steps_per_epoch, | |||
| warmup_ratio=1 / 3) | |||
| lr_stage = np.array(lr_stage).astype('f') | |||
| else: | |||
| raise ValueError("lr type {} is not support.".format(lr_type)) | |||
| lr_base = lr_stage.copy() | |||
| lr_base = lr_base / 4 | |||
| lr_vgg = lr_base.copy() | |||
| vgg_freeze_step = 2000 | |||
| vgg_freeze_step = 2000 // group_size | |||
| lr_vgg[:vgg_freeze_step] = 0 | |||
| return lr_stage, lr_base, lr_vgg | |||
| # zhang add | |||
| def adjust_learning_rate(init_lr, lr_gamma, steps_per_epoch, max_epoch_train, stepvalues): | |||
| lr_stage = np.array([init_lr] * steps_per_epoch * max_epoch_train).astype('f') | |||
| for epoch in stepvalues: | |||
| lr_stage[epoch * steps_per_epoch:] *= lr_gamma | |||
| lr_base = lr_stage.copy() | |||
| lr_base = lr_base / 4 | |||
| lr_vgg = lr_base.copy() | |||
| vgg_freeze_step = 2000 | |||
| lr_vgg[:vgg_freeze_step] = 0 | |||
| return lr_stage, lr_base, lr_vgg | |||
| @@ -13,26 +13,38 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import os | |||
| import argparse | |||
| import mindspore | |||
| from mindspore import context | |||
| from mindspore.context import ParallelMode | |||
| from mindspore.communication.management import init, get_rank, get_group_size | |||
| from mindspore.train import Model | |||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor | |||
| from mindspore.nn.optim import Adam, Momentum | |||
| from mindspore.train.loss_scale_manager import FixedLossScaleManager | |||
| from mindspore.nn.optim import Adam | |||
| from src.dataset import create_dataset | |||
| from src.openposenet import OpenPoseNet | |||
| from src.loss import openpose_loss, BuildTrainNetwork | |||
| from src.loss import openpose_loss, BuildTrainNetwork, TrainOneStepWithClipGradientCell | |||
| from src.config import params | |||
| from src.utils import parse_args, get_lr, load_model, MyLossMonitor | |||
| from src.utils import get_lr, load_model, MyLossMonitor | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) | |||
| parser = argparse.ArgumentParser('mindspore openpose training') | |||
| parser.add_argument('--train_dir', type=str, default='train2017', help='train data dir') | |||
| parser.add_argument('--train_ann', type=str, default='person_keypoints_train2017.json', | |||
| help='train annotations json') | |||
| parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') | |||
| args, _ = parser.parse_known_args() | |||
| args.jsonpath_train = os.path.join(params['data_dir'], 'annotations/' + args.train_ann) | |||
| args.imgpath_train = os.path.join(params['data_dir'], args.train_dir) | |||
| args.maskpath_train = os.path.join(params['data_dir'], 'ignore_mask_train') | |||
| def train(): | |||
| """Train function.""" | |||
| args = parse_args() | |||
| args.outputs_dir = params['save_model_path'] | |||
| @@ -46,11 +58,15 @@ def train(): | |||
| args.outputs_dir = os.path.join(args.outputs_dir, "ckpt_0/") | |||
| args.rank = 0 | |||
| # with out loss_scale | |||
| if args.group_size > 1: | |||
| args.max_epoch = params["max_epoch_train_NP"] | |||
| args.loss_scale = params['loss_scale'] / 2 | |||
| args.lr_steps = list(map(int, params["lr_steps_NP"].split(','))) | |||
| params['train_type'] = params['train_type_NP'] | |||
| params['optimizer'] = params['optimizer_NP'] | |||
| params['group_params'] = params['group_params_NP'] | |||
| else: | |||
| args.max_epoch = params["max_epoch_train"] | |||
| args.loss_scale = params['loss_scale'] | |||
| args.lr_steps = list(map(int, params["lr_steps"].split(','))) | |||
| @@ -58,9 +74,7 @@ def train(): | |||
| print('start create network') | |||
| criterion = openpose_loss() | |||
| criterion.add_flags_recursive(fp32=True) | |||
| network = OpenPoseNet(vggpath=params['vgg_path']) | |||
| # network.add_flags_recursive(fp32=True) | |||
| network = OpenPoseNet(vggpath=params['vgg_path'], vgg_with_bn=params['vgg_with_bn']) | |||
| if params["load_pretrain"]: | |||
| print("load pretrain model:", params["pretrained_model_path"]) | |||
| load_model(network, params["pretrained_model_path"]) | |||
| @@ -72,7 +86,7 @@ def train(): | |||
| print('start create dataset') | |||
| else: | |||
| print('Error: wrong data path') | |||
| return 0 | |||
| num_worker = 20 if args.group_size > 1 else 48 | |||
| de_dataset_train = create_dataset(args.jsonpath_train, args.imgpath_train, args.maskpath_train, | |||
| @@ -90,35 +104,63 @@ def train(): | |||
| lr_stage, lr_base, lr_vgg = get_lr(params['lr'] * args.group_size, | |||
| params['lr_gamma'], | |||
| steps_per_epoch, | |||
| params["max_epoch_train"], | |||
| args.max_epoch, | |||
| args.lr_steps, | |||
| args.group_size) | |||
| vgg19_base_params = list(filter(lambda x: 'base.vgg_base' in x.name, train_net.trainable_params())) | |||
| base_params = list(filter(lambda x: 'base.conv' in x.name, train_net.trainable_params())) | |||
| stages_params = list(filter(lambda x: 'base' not in x.name, train_net.trainable_params())) | |||
| group_params = [{'params': vgg19_base_params, 'lr': lr_vgg}, | |||
| {'params': base_params, 'lr': lr_base}, | |||
| {'params': stages_params, 'lr': lr_stage}] | |||
| opt = Adam(group_params, loss_scale=args.loss_scale) | |||
| train_net.set_train(True) | |||
| loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False) | |||
| model = Model(train_net, optimizer=opt, loss_scale_manager=loss_scale_manager) | |||
| params['ckpt_interval'] = max(steps_per_epoch, params['ckpt_interval']) | |||
| args.group_size, | |||
| lr_type=params['lr_type'], | |||
| warmup_epoch=params['warmup_epoch']) | |||
| # optimizer | |||
| if params['group_params']: | |||
| vgg19_base_params = list(filter(lambda x: 'base.vgg_base' in x.name, train_net.trainable_params())) | |||
| base_params = list(filter(lambda x: 'base.conv' in x.name, train_net.trainable_params())) | |||
| stages_params = list(filter(lambda x: 'base' not in x.name, train_net.trainable_params())) | |||
| group_params = [{'params': vgg19_base_params, 'lr': lr_vgg}, | |||
| {'params': base_params, 'lr': lr_base}, | |||
| {'params': stages_params, 'lr': lr_stage}] | |||
| if params['optimizer'] == "Momentum": | |||
| opt = Momentum(group_params, learning_rate=lr_stage, momentum=0.9) | |||
| elif params['optimizer'] == "Adam": | |||
| opt = Adam(group_params) | |||
| else: | |||
| raise ValueError("optimizer not support.") | |||
| else: | |||
| if params['optimizer'] == "Momentum": | |||
| opt = Momentum(train_net.trainable_params(), learning_rate=lr_stage, momentum=0.9) | |||
| elif params['optimizer'] == "Adam": | |||
| opt = Adam(train_net.trainable_params(), learning_rate=lr_stage) | |||
| else: | |||
| raise ValueError("optimizer not support.") | |||
| # callback | |||
| config_ck = CheckpointConfig(save_checkpoint_steps=params['ckpt_interval'], | |||
| keep_checkpoint_max=params["keep_checkpoint_max"]) | |||
| ckpoint_cb = ModelCheckpoint(prefix='{}'.format(args.rank), directory=args.outputs_dir, config=config_ck) | |||
| time_cb = TimeMonitor(data_size=de_dataset_train.get_dataset_size()) | |||
| callback_list = [MyLossMonitor(), time_cb, ckpoint_cb] | |||
| if args.rank == 0: | |||
| callback_list = [MyLossMonitor(), time_cb, ckpoint_cb] | |||
| else: | |||
| callback_list = [MyLossMonitor(), time_cb] | |||
| # train | |||
| if params['train_type'] == 'clip_grad': | |||
| train_net = TrainOneStepWithClipGradientCell(train_net, opt, sens=args.loss_scale) | |||
| train_net.set_train() | |||
| model = Model(train_net) | |||
| elif params['train_type'] == 'fix_loss_scale': | |||
| loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False) | |||
| train_net.set_train() | |||
| model = Model(train_net, optimizer=opt, loss_scale_manager=loss_scale_manager) | |||
| else: | |||
| raise ValueError("Type {} is not support.".format(params['train_type'])) | |||
| print("============== Starting Training ==============") | |||
| model.train(params["max_epoch_train"], de_dataset_train, callbacks=callback_list, | |||
| model.train(args.max_epoch, de_dataset_train, callbacks=callback_list, | |||
| dataset_sink_mode=False) | |||
| return 0 | |||
| if __name__ == "__main__": | |||
| # mindspore.common.seed.set_seed(1) | |||
| mindspore.common.seed.set_seed(1) | |||
| train() | |||