@@ -142,10 +142,10 @@ Parameters for both training and evaluation can be set in config.py
 'batch_size': 10                     # training batch size
 'lr_gamma': 0.1                      # lr scale when reach lr_steps
 'lr_steps': '100000,200000,250000'   # the steps when lr * lr_gamma
-'loss scale': 16386                  # the loss scale of mixed precision
+'loss_scale': 16384                  # the loss scale of mixed precision
 'max_epoch_train': 60                # total training epochs
 'insize': 368                        # image size used as input to the model
-'keep_checkpoint_max': 5             # only keep the last keep_checkpoint_max checkpoint
+'keep_checkpoint_max': 1             # only keep the last keep_checkpoint_max checkpoint
 'log_interval': 100                  # the interval of print a log
 'ckpt_interval': 5000                # the interval of saving an output model
 ```
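A note on the hunk above: `loss_scale` is the fixed multiplier for mixed-precision training; the fp16 loss is scaled up by 16384 (2^14) before backprop so small gradients do not underflow, and the gradients are divided by the same factor before the optimizer step. This repo threads the scale through its own training cells (see the training-wrapper hunks later in this diff); purely as a standalone illustration, a hedged sketch of fixed loss scaling with MindSpore's high-level `Model` API, where `net`, `loss_fn` and `opt` are placeholders:

```python
# Hedged sketch of fixed loss scaling with MindSpore's Model API; `net`,
# `loss_fn` and `opt` are placeholders, not objects from this repository.
from mindspore import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager

loss_scale = 16384  # 2 ** 14, matching params['loss_scale']
# drop_overflow_update=False keeps applying the fixed scale every step
# instead of skipping updates when an overflow is detected.
manager = FixedLossScaleManager(loss_scale, drop_overflow_update=False)
model = Model(net, loss_fn=loss_fn, optimizer=opt, loss_scale_manager=manager)
```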
@@ -195,7 +195,7 @@ For more configuration details, please refer to the script `config.py`.
 ```python
 # grep "AP" eval.log
-{'AP': 0.40030956300341397, 'Ap .5': 0.6658941566481336, 'AP .75': 0.396047897339743, 'AP (M)': 0.3075356543635785, 'AP (L)': 0.533772768618845, 'AR': 0.4519836272040302, 'AR .5': 0.693639798488665, 'AR .75': 0.4570214105793451, 'AR (M)': 0.32155148866429945, 'AR (L)': 0.6330360460795242}
+{'AP': 0.39830956300341397, 'Ap .5': 0.6658941566481336, 'AP .75': 0.396047897339743, 'AP (M)': 0.3075356543635785, 'AP (L)': 0.533772768618845, 'AR': 0.4519836272040302, 'AR .5': 0.693639798488665, 'AR .75': 0.4570214105793451, 'AR (M)': 0.32155148866429945, 'AR (L)': 0.6330360460795242}
 ```
@@ -209,14 +209,14 @@ For more configuration details, please refer to the script `config.py`.
 | -------------------------- | ------------------------------------------------------------ |
 | Model Version              | openpose                                                      |
 | Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G               |
-| uploaded Date              | 10/20/2020 (month/day/year)                                   |
+| uploaded Date              | 12/14/2020 (month/day/year)                                   |
 | MindSpore Version          | 1.0.1-alpha                                                   |
-| Training Parameters        | epoch = 60, steps = 30k, batch_size = 10, lr = 0.0001         |
-| Optimizer                  | Adam                                                          |
+| Training Parameters        | epoch=60(1pcs)/80(8pcs), steps=30k(1pcs)/5k(8pcs), batch_size=10, init_lr=0.0001 |
+| Optimizer                  | Adam(1pcs)/Momentum(8pcs)                                     |
 | Loss Function              | MSE                                                           |
 | outputs                    | pose                                                          |
-| Speed                      | 1pc: 29imgs/s                                                 |
-| Total time                 | 1pc: 30h                                                      |
+| Speed                      | 1pcs: 35fps, 8pcs: 230fps                                     |
+| Total time                 | 1pcs: 22.5h, 8pcs: 5.1h                                       |
 | Checkpoint for Fine tuning | 602.33M (.ckpt file)                                          |
@@ -17,10 +17,9 @@ import os
 import argparse
 import warnings
 import sys
-import cv2
-from tqdm import tqdm
 import numpy as np
+from tqdm import tqdm
+import cv2
 from scipy.ndimage.filters import gaussian_filter
 from pycocotools.coco import COCO as LoadAnn
 from pycocotools.cocoeval import COCOeval as MapEval
@@ -30,9 +29,10 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.communication.management import init, get_rank, get_group_size
 from mindspore.common import dtype as mstype
-from src.dataset import valdata
-from src.openposenet import OpenPoseNet
 from src.config import params, JointType
+from src.openposenet import OpenPoseNet
+from src.dataset import valdata
 warnings.filterwarnings("ignore")
 devid = int(os.getenv('DEVICE_ID'))
@@ -40,6 +40,18 @@ context.set_context(mode=context.GRAPH_MODE,
                     device_target="Ascend", save_graphs=False, device_id=devid)
 show_gt = 0
+parser = argparse.ArgumentParser('mindspore openpose_net test')
+parser.add_argument('--model_path', type=str, default='./0-33_170000.ckpt', help='path of testing model')
+parser.add_argument('--imgpath_val', type=str, default='./dataset/coco/val2017', help='path of testing imgs')
+parser.add_argument('--ann', type=str, default='./dataset/coco/annotations/person_keypoints_val2017.json',
+                    help='path of annotations')
+parser.add_argument('--output_path', type=str, default='./output_img', help='path of testing imgs')
+# distributed related
+parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
+parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
+parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
+args, _ = parser.parse_known_args()
 def evaluate_mAP(res_file, ann_file, ann_type='keypoints', silence=True):
     class NullWriter():
         def write(self, arg):
@@ -68,23 +80,6 @@ def evaluate_mAP(res_file, ann_file, ann_type='keypoints', silence=True):
     return info_str
-def parse_args():
-    """Parse arguments."""
-    parser = argparse.ArgumentParser('mindspore openpose_net test')
-    parser.add_argument('--model_path', type=str, default='./scripts/train_parallel0/checkpoints/ckpt_0/0-60_663.ckpt',
-                        help='path of testing model')
-    parser.add_argument('--imgpath_val', type=str, default='/data0/zhy/dataset/coco/val2017',
-                        help='path of testing imgs')
-    parser.add_argument('--ann', type=str, default='/data0/zhy/dataset/coco/annotations/person_keypoints_val2017.json',
-                        help='path of annotations')
-    parser.add_argument('--output_path', type=str, default='./output_img', help='path of testing imgs')
-    # distributed related
-    parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
-    parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
-    parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
-    args, _ = parser.parse_known_args()
-    return args
 def load_model(test_net, model_path):
     assert os.path.exists(model_path)
@@ -178,7 +173,7 @@ def compute_peaks_from_heatmaps(heatmaps):
     return all_peaks
-def compute_candidate_connections(paf, cand_a, cand_b, img_len, cfg):
+def compute_candidate_connections(paf, cand_a, cand_b, img_len, params_):
     candidate_connections = []
     for joint_a in cand_a:
         for joint_b in cand_b:
@@ -186,33 +181,33 @@ def compute_candidate_connections(paf, cand_a, cand_b, img_len, cfg):
             norm = np.linalg.norm(vector)
             if norm == 0:
                 continue
-            ys = np.linspace(joint_a[1], joint_b[1], num=cfg['n_integ_points'])
-            xs = np.linspace(joint_a[0], joint_b[0], num=cfg['n_integ_points'])
+            ys = np.linspace(joint_a[1], joint_b[1], num=params_['n_integ_points'])
+            xs = np.linspace(joint_a[0], joint_b[0], num=params_['n_integ_points'])
             integ_points = np.stack([ys, xs]).T.round().astype('i')
             paf_in_edge = np.hstack([paf[0][np.hsplit(integ_points, 2)], paf[1][np.hsplit(integ_points, 2)]])
             unit_vector = vector / norm
             inner_products = np.dot(paf_in_edge, unit_vector)
             integ_value = inner_products.sum() / len(inner_products)
-            integ_value_with_dist_prior = integ_value + min(cfg['limb_length_ratio'] * img_len / norm -
-                                                            cfg['length_penalty_value'], 0)
-            n_valid_points = sum(inner_products > cfg['inner_product_thresh'])
-            if n_valid_points > cfg['n_integ_points_thresh'] and integ_value_with_dist_prior > 0:
+            integ_value_with_dist_prior = integ_value + min(params_['limb_length_ratio'] * img_len / norm -
+                                                            params_['length_penalty_value'], 0)
+            n_valid_points = sum(inner_products > params_['inner_product_thresh'])
+            if n_valid_points > params_['n_integ_points_thresh'] and integ_value_with_dist_prior > 0:
                 candidate_connections.append([int(joint_a[3]), int(joint_b[3]), integ_value_with_dist_prior])
     candidate_connections = sorted(candidate_connections, key=lambda x: x[2], reverse=True)
     return candidate_connections
-def compute_connections(pafs, all_peaks, img_len, cfg):
+def compute_connections(pafs, all_peaks, img_len, params_):
     all_connections = []
-    for i in range(len(cfg['limbs_point'])):
+    for i in range(len(params_['limbs_point'])):
         paf_index = [i * 2, i * 2 + 1]
         paf = pafs[paf_index]  # shape: (2, 320, 320)
-        limb_point = cfg['limbs_point'][i]  # example: [<JointType.Neck: 1>, <JointType.RightWaist: 8>]
+        limb_point = params_['limbs_point'][i]  # example: [<JointType.Neck: 1>, <JointType.RightWaist: 8>]
         cand_a = all_peaks[all_peaks[:, 0] == limb_point[0]][:, 1:]
         cand_b = all_peaks[all_peaks[:, 0] == limb_point[1]][:, 1:]
         if cand_a.shape[0] > 0 and cand_b.shape[0] > 0:
-            candidate_connections = compute_candidate_connections(paf, cand_a, cand_b, img_len, cfg)
+            candidate_connections = compute_candidate_connections(paf, cand_a, cand_b, img_len, params_)
             connections = np.zeros((0, 3))
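A note on the scoring in `compute_candidate_connections` above: `integ_value` is the PAF's average inner product with the limb's unit vector over `n_integ_points` sample points, and the distance prior can only subtract score, never add it, because it is capped at 0. A toy calculation with illustrative parameter values (the real ones live in `src/config.py`, which this diff does not show):

```python
# Hedged toy example of the connection score; parameter values are
# illustrative stand-ins for the real entries in src/config.py.
import numpy as np

params_ = {'limb_length_ratio': 1.0, 'length_penalty_value': 1.0}

img_len = 368.0
joint_a = np.array([100.0, 100.0])   # (x, y)
joint_b = np.array([100.0, 284.0])   # a vertical limb of length 184
norm = np.linalg.norm(joint_b - joint_a)

# A perfect PAF points along the limb at every sample point, so the
# averaged inner product is exactly 1.0.
integ_value = 1.0

# The prior is capped at 0: short limbs keep their raw score, while limbs
# much longer than limb_length_ratio * img_len lose score.
penalty = min(params_['limb_length_ratio'] * img_len / norm -
              params_['length_penalty_value'], 0)
print(integ_value + penalty)  # 1.0 here: 368/184 - 1 = 1 > 0, so penalty is 0
```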
@@ -226,11 +221,11 @@ def compute_connections(pafs, all_peaks, img_len, cfg):
         all_connections.append(np.zeros((0, 3)))
     return all_connections
-def grouping_key_points(all_connections, candidate_peaks, cfg):
+def grouping_key_points(all_connections, candidate_peaks, params_):
     subsets = -1 * np.ones((0, 20))
     for l, connections in enumerate(all_connections):
-        joint_a, joint_b = cfg['limbs_point'][l]
+        joint_a, joint_b = params_['limbs_point'][l]
         for ind_a, ind_b, score in connections[:, :3]:
             ind_a, ind_b = int(ind_a), int(ind_b)
             joint_found_cnt = 0
@@ -249,6 +244,7 @@ def grouping_key_points(all_connections, candidate_peaks, cfg):
                 found_subset[-1] += 1  # increment joint count
                 found_subset[-2] += candidate_peaks[ind_b, 3] + score
             elif joint_found_cnt == 2:
                 found_subset_1 = subsets[joint_found_subset_index[0]]
@@ -289,10 +285,8 @@ def grouping_key_points(all_connections, candidate_peaks, cfg):
             pass
     # delete low score subsets
-    keep = np.logical_and(subsets[:, -1] >= cfg['n_subset_limbs_thresh'],
-                          subsets[:, -2] / subsets[:, -1] >= cfg['subset_score_thresh'])
-    # cfg['n_subset_limbs_thresh'] = 3
-    # cfg['subset_score_thresh'] = 0.2
+    keep = np.logical_and(subsets[:, -1] >= params_['n_subset_limbs_thresh'],
+                          subsets[:, -2] / subsets[:, -1] >= params_['subset_score_thresh'])
     subsets = subsets[keep]
     return subsets
@@ -319,7 +313,7 @@ def detect(img, network):
     # map_w, map_h = compute_optimal_size(orig_img, params['heatmap_size'])  # 320
     map_w, map_h = compute_optimal_size(orig_img, params['inference_img_size'])
-    print("image size is: ", input_w, input_h)
+    # print("image size is: ", input_w, input_h)
     resized_image = cv2.resize(orig_img, (input_w, input_h))
     x_data = preprocess(resized_image)
@@ -394,7 +388,7 @@ def draw_person_pose(orig_img, poses):
     return canvas
 def depreprocess(img):
-    # x_data = img.astype('f')
+    #x_data = img.astype('f')
     x_data = img[0]
     x_data += 0.5
     x_data *= 255
@@ -402,15 +396,14 @@ def depreprocess(img):
     x_data = x_data.transpose(1, 2, 0)
     return x_data
-def _eval():
-    args = parse_args()
+def val():
     if args.is_distributed:
         init()
         args.rank = get_rank()
         args.group_size = get_group_size()
     if not os.path.exists(args.output_path):
         os.mkdir(args.output_path)
-    network = OpenPoseNet()
+    network = OpenPoseNet(vgg_with_bn=params['vgg_with_bn'])
     network.set_train(False)
     load_model(network, args.model_path)
@@ -455,4 +448,4 @@ def _eval():
     print('result: ', res)
 if __name__ == "__main__":
-    _eval()
+    val()
@@ -15,9 +15,8 @@
 # ============================================================================
 export DEVICE_ID=0
-export RANK_ID=0
 python eval.py \
-    --model_path ./scripts/train_parallel0/checkpoints/ckpt_0/0-60_663.ckpt \
-    --imgpath_val /data0/zhy/dataset/coco/val2017 \
-    --ann /data0/zhy/dataset/coco/annotations/person_keypoints_val2017.json \
+    --model_path ./scripts/train_parallel0/checkpoints/ckpt_0/0-80_663.ckpt \
+    --imgpath_val ./dataset/val2017 \
+    --ann ./dataset/annotations/person_keypoints_val2017.json \
     > eval.log 2>&1 &
@@ -14,5 +14,6 @@
 # limitations under the License.
 # ============================================================================
+export DEVICE_ID=0
 cd ..
 python train.py --train_dir train2017 --train_ann person_keypoints_train2017.json > scripts/train.log 2>&1 &
@@ -53,21 +53,41 @@ class JointType(IntEnum):
 params = {
     # paths
-    'data_dir': '/data0/zhy/dataset/coco',
-    'vgg_path': '/data0/zhy/dataset/coco/vgg19-0-97_5004.ckpt',
+    'data_dir': './dataset',
     'save_model_path': './checkpoints/',
     'load_pretrain': False,
     'pretrained_model_path': "",
-    # training params
-    'batch_size': 10,
-    'lr': 1e-4,
-    'lr_gamma': 0.1,
-    'lr_steps': '100000,200000,250000',
-    'lr_steps_NP': '250000',
+    # train type
+    'train_type': 'fix_loss_scale',  # choose in ['clip_grad', 'fix_loss_scale']
+    'train_type_NP': 'clip_grad',
+    # vgg bn
+    'vgg_with_bn': False,
+    'vgg_path': './vgg_model/vgg19-0-97_5004.ckpt',
+    # if clip_grad
+    'GRADIENT_CLIP_TYPE': 1,
+    'GRADIENT_CLIP_VALUE': 10.0,
-    'loss_scale': 16386,
+    # optimizer and lr
+    'optimizer': "Adam",  # choose in ['Momentum', 'Adam']
+    'optimizer_NP': "Momentum",
+    'group_params': True,
+    'group_params_NP': False,
+    'lr': 1e-4,
+    'lr_type': 'default',  # choose in ["default", "cosine"]
+    'lr_gamma': 0.1,  # if default
+    'lr_steps': '100000,200000,250000',  # if default
+    'lr_steps_NP': '250000,300000',  # if default
+    'warmup_epoch': 5,  # if cosine
     'max_epoch_train': 60,
+    'max_epoch_train_NP': 80,
+    'loss_scale': 16384,
+    # default param
+    'batch_size': 10,
     'min_keypoints': 5,
     'min_area': 32 * 32,
     'insize': 368,
@@ -75,9 +95,9 @@ params = {
     'paf_sigma': 8,
     'heatmap_sigma': 7,
     'eva_num': 100,
-    'keep_checkpoint_max': 5,
+    'keep_checkpoint_max': 1,
     'log_interval': 100,
-    'ckpt_interval': 663,  # 5000,
+    'ckpt_interval': 5304,
     'min_box_size': 64,
     'max_box_size': 512,
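Note that `lr_steps` and `lr_steps_NP` are stored as comma-separated strings, so whoever builds the schedule must split them into integers before passing them to `get_lr` (see the utils hunk at the end of this diff). A hedged sketch of that conversion, assuming it happens in `train.py`, which this diff does not show:

```python
# Hedged sketch: consuming the comma-separated lr_steps string.
# The actual conversion would live in train.py (not part of this diff).
params = {'lr_steps': '100000,200000,250000'}
lr_steps = list(map(int, params['lr_steps'].split(',')))
print(lr_steps)  # [100000, 200000, 250000]
```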
@@ -15,10 +15,10 @@
 import os
 import math
 import random
-import cv2
 import numpy as np
+import cv2
 from pycocotools.coco import COCO as ReadJson
 import mindspore.dataset as de
 from src.config import JointType, params
@@ -41,6 +41,7 @@ class txtdataset():
         self.imgIds = random.sample(self.imgIds, n_samples)
         print('{} images: {}'.format(mode, len(self)))
     def __len__(self):
         return len(self.imgIds)
@@ -217,9 +218,9 @@ class txtdataset():
         flipped_mask = cv2.flip(mask.astype(np.uint8), 1).astype('bool')
         poses[:, :, 0] = img.shape[1] - 1 - poses[:, :, 0]
-        def swap_joints(poses, joint_type_1, joint_type_2):
-            tmp = poses[:, joint_type_1].copy()
-            poses[:, joint_type_1] = poses[:, joint_type_2]
+        def swap_joints(poses, joint_type_, joint_type_2):
+            tmp = poses[:, joint_type_].copy()
+            poses[:, joint_type_] = poses[:, joint_type_2]
             poses[:, joint_type_2] = tmp
         swap_joints(poses, JointType.LeftEye, JointType.RightEye)
@@ -243,8 +244,10 @@ class txtdataset():
             aug_img, ignore_mask, poses = self.flip_img(aug_img, ignore_mask, poses)
         return aug_img, ignore_mask, poses
-    # ------------------------------------------------------------------
+    # ------------------------------- end -----------------------------------
+    # ------------------------------ Heatmap ------------------------------------
     # return shape: (height, width)
     def generate_gaussian_heatmap(self, shape, joint, sigma):
         x, y = joint
@@ -269,6 +272,38 @@ class txtdataset():
         heatmaps = np.vstack((heatmaps, bg_heatmap[None]))
         return heatmaps.astype('f')
+    def generate_gaussian_heatmap_fast(self, shape, joint, sigma):
+        x, y = joint
+        grid_x = np.tile(np.arange(shape[1]), (shape[0], 1))
+        grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose()
+        grid_x = grid_x + 0.4375
+        grid_y = grid_y + 0.4375
+        grid_distance = (grid_x - x) ** 2 + (grid_y - y) ** 2
+        gaussian_heatmap = np.exp(-0.5 * grid_distance / sigma**2)
+        return gaussian_heatmap
+    def generate_heatmaps_fast(self, img, poses, heatmap_sigma):
+        resize_shape = (img.shape[0] // 8, img.shape[1] // 8)
+        heatmaps = np.zeros((0,) + resize_shape)
+        sum_heatmap = np.zeros(resize_shape)
+        for joint_index in range(len(JointType)):
+            heatmap = np.zeros(resize_shape)
+            for pose in poses:
+                if pose[joint_index, 2] > 0:
+                    jointmap = self.generate_gaussian_heatmap_fast(resize_shape, pose[joint_index][:2]/8,
+                                                                   heatmap_sigma/8)
+                    index_1 = jointmap > heatmap
+                    heatmap[index_1] = jointmap[index_1]
+                    index_2 = jointmap > sum_heatmap
+                    sum_heatmap[index_2] = jointmap[index_2]
+            heatmaps = np.vstack((heatmaps, heatmap.reshape((1,) + heatmap.shape)))
+        bg_heatmap = 1 - sum_heatmap  # background channel
+        heatmaps = np.vstack((heatmaps, bg_heatmap[None]))
+        return heatmaps.astype('f')
+    # ------------------------------ end ------------------------------------
+    # ------------------------------ PAF ------------------------------------
     # return shape: (2, height, width)
     def generate_constant_paf(self, shape, joint_from, joint_to, paf_width):
         if np.array_equal(joint_from, joint_to):  # same joint
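The 0.4375 offset in the `*_fast` generators above comes from drawing targets directly at 1/8 resolution: output cell `i` covers input pixels `[8*i, 8*i + 8)`, whose center is `8*i + 3.5`, i.e. `i + 3.5/8 = i + 0.4375` in output coordinates. Comparing `joint/8` against the offset grid therefore measures distance to cell centers rather than corners. A quick check:

```python
# Quick check of the 0.4375 half-cell offset used by the *_fast generators.
print(3.5 / 8)  # 0.4375
for i in range(3):
    center_in_input = 8 * i + 3.5            # center of output cell i, in input coords
    print(center_in_input / 8, i + 0.4375)   # identical: 0.4375, 1.4375, 2.4375
```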
@@ -285,7 +320,7 @@ class txtdataset():
         grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose()
         horizontal_inner_product = unit_vector[0] * (grid_x - joint_from[0]) + unit_vector[1] * (grid_y - joint_from[1])
         horizontal_paf_flag = (horizontal_inner_product >= 0) & (horizontal_inner_product <= joint_distance)
-        vertical_inner_product = vertical_unit_vector[0] * (grid_x - joint_from[0]) + vertical_unit_vector[1] * \
+        vertical_inner_product = vertical_unit_vector[0] * (grid_x - joint_from[0]) + vertical_unit_vector[1] *\
                                  (grid_y - joint_from[1])
         vertical_paf_flag = np.abs(vertical_inner_product) <= paf_width  # paf_width : 8
         paf_flag = horizontal_paf_flag & vertical_paf_flag
@@ -314,6 +349,55 @@ class txtdataset():
             pafs = np.vstack((pafs, paf))
         return pafs.astype('f')
+    def generate_constant_paf_fast(self, shape, joint_from, joint_to, paf_width):
+        if np.array_equal(joint_from, joint_to):  # same joint
+            return np.zeros((2,) + shape[:-1])
+        joint_distance = np.linalg.norm(joint_to - joint_from)
+        unit_vector = (joint_to - joint_from) / joint_distance
+        rad = np.pi / 2
+        # [[0, 1], [-1, 0]]
+        rot_matrix = np.array([[np.cos(rad), np.sin(rad)], [-np.sin(rad), np.cos(rad)]])
+        # [[u_y], [-u_x]]
+        vertical_unit_vector = np.dot(rot_matrix, unit_vector)
+        grid_x = np.tile(np.arange(shape[1]), (shape[0], 1))
+        grid_y = np.tile(np.arange(shape[0]), (shape[1], 1)).transpose()
+        grid_x = grid_x + 0.4375
+        grid_y = grid_y + 0.4375
+        horizontal_inner_product = unit_vector[0] * (grid_x - joint_from[0]) + unit_vector[1] * (grid_y - joint_from[1])
+        horizontal_paf_flag = (horizontal_inner_product >= 0) & (horizontal_inner_product <= joint_distance)
+        vertical_inner_product = vertical_unit_vector[0] * (grid_x - joint_from[0]) + vertical_unit_vector[1] *\
+                                 (grid_y - joint_from[1])
+        vertical_paf_flag = np.abs(vertical_inner_product) <= paf_width  # paf_width : 8/8 = 1
+        paf_flag = horizontal_paf_flag & vertical_paf_flag
+        constant_paf = np.stack((paf_flag, paf_flag)) *\
+                       np.broadcast_to(unit_vector, shape[:-1] + (2,)).transpose(2, 0, 1)
+        return constant_paf
+    def generate_pafs_fast(self, img, poses, paf_sigma):
+        resize_shape = (img.shape[0]//8, img.shape[1]//8, 3)
+        pafs = np.zeros((0,) + resize_shape[:-1])
+        for limb in params['limbs_point']:
+            paf = np.zeros((2,) + resize_shape[:-1])
+            paf_flags = np.zeros(paf.shape)  # for constant paf
+            for pose in poses:
+                joint_from, joint_to = pose[limb]
+                if joint_from[2] > 0 and joint_to[2] > 0:
+                    limb_paf = self.generate_constant_paf_fast(resize_shape, joint_from[:2]/8, joint_to[:2]/8, paf_sigma/8)  # shape: (2, H//8, W//8)
+                    limb_paf_flags = limb_paf != 0
+                    paf_flags += np.broadcast_to(limb_paf_flags[0] | limb_paf_flags[1], limb_paf.shape)
+                    paf += limb_paf
+            index_1 = paf_flags > 0
+            paf[index_1] /= paf_flags[index_1]
+            pafs = np.vstack((pafs, paf))
+        return pafs.astype('f')
+    # ------------------------------ end ------------------------------------
     def get_img_annotation(self, ind=None, img_id=None):
         annotations = None
@@ -389,14 +473,18 @@ class txtdataset():
         resized_img, ignore_mask, resized_poses = self.resize_data(img, ignore_mask, poses,
                                                                    shape=(self.insize, self.insize))
-        heatmaps = self.generate_heatmaps(resized_img, resized_poses, params['heatmap_sigma'])
-        pafs = self.generate_pafs(resized_img, resized_poses, params['paf_sigma'])  # params['paf_sigma']: 8
+        # heatmaps = self.generate_heatmaps(resized_img, resized_poses, params['heatmap_sigma'])
+        # resized_heatmaps = self.resize_output(heatmaps)
+        resized_heatmaps = self.generate_heatmaps_fast(resized_img, resized_poses, params['heatmap_sigma'])
+        # pafs = self.generate_pafs(resized_img, resized_poses, params['paf_sigma'])
+        # resized_pafs = self.resize_output(pafs)
+        resized_pafs = self.generate_pafs_fast(resized_img, resized_poses, params['paf_sigma'])
         ignore_mask = cv2.morphologyEx(ignore_mask.astype('uint8'), cv2.MORPH_DILATE, np.ones((16, 16))).astype('bool')
-        resized_pafs = self.resize_output(pafs)
-        resized_heatmaps = self.resize_output(heatmaps)
         resized_ignore_mask = self.resize_output(ignore_mask)
         return resized_img, resized_pafs, resized_heatmaps, resized_ignore_mask
     def preprocess(self, img):
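Net effect of this hunk: ground-truth heatmaps and PAFs are now drawn directly on the `insize // 8` grid instead of being rendered at full resolution and then resized, removing two 368x368 rasterizations per sample. A hedged shape sketch; the 18-joint/19-limb layout below is the usual OpenPose COCO configuration, assumed here rather than read from `config.py`:

```python
# Hedged shape sketch for the fast target generators above.
import numpy as np

insize = 368                       # params['insize']
img = np.zeros((insize, insize, 3), np.uint8)
grid = (img.shape[0] // 8, img.shape[1] // 8)
print(grid)                        # (46, 46): resolution the network is supervised at

# generate_heatmaps_fast stacks one 46x46 map per joint plus one background
# channel; generate_pafs_fast stacks an (x, y) pair per limb.
n_joints, n_limbs = 18, 19         # typical OpenPose layout, assumed here
print((n_joints + 1,) + grid)      # heatmap target shape: (19, 46, 46)
print((2 * n_limbs,) + grid)       # PAF target shape: (38, 46, 46)
```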
@@ -459,7 +547,6 @@ class DistributedSampler():
     def __len__(self):
         return self.num_samplers
 def valdata(jsonpath, imgpath, rank, group_size, mode='val', maskpath=''):
     #cv2.setNumThreads(0)
     val = ReadJson(jsonpath)
@@ -470,23 +557,6 @@ def valdata(jsonpath, imgpath, rank, group_size, mode='val', maskpath=''):
     return ds
-def openpose(jsonpath, imgpath, maskpath, per_batch_size, max_epoch, rank, group_size, mode='train'):
-    train = ReadJson(jsonpath)
-    num_parallel = 48
-    if group_size > 1:
-        num_parallel = 20
-    dataset = txtdataset(train, imgpath, maskpath, params['insize'], mode=mode)
-    sampler = DistributedSampler(dataset, rank, group_size)
-    de_dataset = de.GeneratorDataset(dataset, ["image", "pafs", "heatmaps", "ignore_mask"],
-                                     num_parallel_workers=num_parallel, sampler=sampler, shuffle=True)
-    de_dataset = de_dataset.project(columns=["image", "pafs", "heatmaps", "ignore_mask"])
-    de_dataset = de_dataset.batch(batch_size=per_batch_size, drop_remainder=True, num_parallel_workers=num_parallel)
-    steap_pre_epoch = de_dataset.get_dataset_size()
-    de_dataset = de_dataset.repeat(max_epoch)
-    return de_dataset, steap_pre_epoch
 def create_dataset(jsonpath, imgpath, maskpath, batch_size, rank, group_size, mode='train', repeat_num=1, shuffle=True,
                    multiprocessing=True, num_worker=20):
@@ -15,7 +15,6 @@
 import os
 import argparse
 import cv2
 import numpy as np
 from tqdm import tqdm
 from pycocotools.coco import COCO as ReadJson
@@ -23,44 +22,44 @@ from pycocotools.coco import COCO as ReadJson
 from config import params
 class DataLoader():
-    def __init__(self, coco, dir_name, data_mode='train'):
-        self.train = coco
+    def __init__(self, train_, dir_name, mode_='train'):
+        self.train = train_
         self.dir_name = dir_name
-        assert data_mode in ['train', 'val'], 'Data loading mode is invalid.'
-        self.mode = data_mode
-        self.catIds = coco.getCatIds()  # catNms=['person']
-        self.imgIds = sorted(coco.getImgIds(catIds=self.catIds))
+        assert mode_ in ['train', 'val'], 'Data loading mode is invalid.'
+        self.mode = mode_
+        self.catIds = train_.getCatIds()  # catNms=['person']
+        self.imgIds = sorted(train_.getImgIds(catIds=self.catIds))
     def __len__(self):
         return len(self.imgIds)
-    def gen_masks(self, image, anns):
-        _mask_all = np.zeros(image.shape[:2], 'bool')
-        _mask_miss = np.zeros(image.shape[:2], 'bool')
-        for ann in anns:
+    def gen_masks(self, image_, annotations_):
+        mask_all_1 = np.zeros(image_.shape[:2], 'bool')
+        mask_miss_1 = np.zeros(image_.shape[:2], 'bool')
+        for ann in annotations_:
             mask = self.train.annToMask(ann).astype('bool')
             if ann['iscrowd'] == 1:
-                intxn = _mask_all & mask
-                _mask_miss = np.bitwise_or(_mask_miss.astype(int), np.subtract(mask, intxn, dtype=np.int32))
-                _mask_all = np.bitwise_or(_mask_all.astype(int), mask.astype(int))
+                intxn = mask_all_1 & mask
+                mask_miss_1 = np.bitwise_or(mask_miss_1.astype(int), np.subtract(mask, intxn, dtype=np.int32))
+                mask_all_1 = np.bitwise_or(mask_all_1.astype(int), mask.astype(int))
             elif ann['num_keypoints'] < params['min_keypoints'] or ann['area'] <= params['min_area']:
-                _mask_all = np.bitwise_or(_mask_all.astype(int), mask.astype(int))
-                _mask_miss = np.bitwise_or(_mask_miss.astype(int), mask.astype(int))
+                mask_all_1 = np.bitwise_or(mask_all_1.astype(int), mask.astype(int))
+                mask_miss_1 = np.bitwise_or(mask_miss_1.astype(int), mask.astype(int))
             else:
-                _mask_all = np.bitwise_or(_mask_all.astype(int), mask.astype(int))
-        return _mask_all, _mask_miss
+                mask_all_1 = np.bitwise_or(mask_all_1.astype(int), mask.astype(int))
+        return mask_all_1, mask_miss_1
-    def dwaw_gen_masks(self, image, mask, color=(0, 0, 1)):
+    def dwaw_gen_masks(self, image_, mask, color=(0, 0, 1)):
         bimsk = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
-        mskd = image * bimsk.astype(np.int32)
+        mskd = image_ * bimsk.astype(np.int32)
         clmsk = np.ones(bimsk.shape) * bimsk
-        for ind in range(3):
-            clmsk[:, :, ind] = clmsk[:, :, ind] * color[ind] * 255
-        image = image + 0.7 * clmsk - 0.7 * mskd
-        return image.astype(np.uint8)
+        for index_1 in range(3):
+            clmsk[:, :, index_1] = clmsk[:, :, index_1] * color[index_1] * 255
+        image_ = image_ + 0.7 * clmsk - 0.7 * mskd
+        return image_.astype(np.uint8)
-    def draw_masks_and_keypoints(self, image, anns):
-        for ann in anns:
+    def draw_masks_and_keypoints(self, image_, annotations_):
+        for ann in annotations_:
             # masks
             mask = self.train.annToMask(ann).astype(np.uint8)
             if ann['iscrowd'] == 1:
@@ -70,30 +69,30 @@ class DataLoader():
             else:
                 color = (1, 0, 0)
             bimsk = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
-            mskd = image * bimsk.astype(np.int32)
+            mskd = image_ * bimsk.astype(np.int32)
             clmsk = np.ones(bimsk.shape) * bimsk
-            for ind in range(3):
-                clmsk[:, :, ind] = clmsk[:, :, ind] * color[ind] * 255
-            image = image + 0.7 * clmsk - 0.7 * mskd
+            for index_1 in range(3):
+                clmsk[:, :, index_1] = clmsk[:, :, index_1] * color[index_1] * 255
+            image_ = image_ + 0.7 * clmsk - 0.7 * mskd
             # keypoints
             for x, y, v in np.array(ann['keypoints']).reshape(-1, 3):
                 if v == 1:
-                    cv2.circle(image, (x, y), 3, (255, 255, 0), -1)
+                    cv2.circle(image_, (x, y), 3, (255, 255, 0), -1)
                 elif v == 2:
-                    cv2.circle(image, (x, y), 3, (255, 0, 255), -1)
-        return image.astype(np.uint8)
+                    cv2.circle(image_, (x, y), 3, (255, 0, 255), -1)
+        return image_.astype(np.uint8)
-    def get_img_annotation(self, ind=None, image_id=None):
+    def get_img_annotation(self, ind=None, img_id_=None):
         if ind is not None:
-            image_id = self.imgIds[ind]
+            img_id_ = self.imgIds[ind]
-        anno_ids = self.train.getAnnIds(imgIds=[image_id])
-        _annotations = self.train.loadAnns(anno_ids)
+        anno_ids = self.train.getAnnIds(imgIds=[img_id_])
+        annotations_ = self.train.loadAnns(anno_ids)
-        img_file = os.path.join(params['data_dir'], self.dir_name, self.train.loadImgs([image_id])[0]['file_name'])
-        _image = cv2.imread(img_file)
-        return _image, _annotations, image_id
+        img_file = os.path.join(params['data_dir'], self.dir_name, self.train.loadImgs([img_id_])[0]['file_name'])
+        image_ = cv2.imread(img_file)
+        return image_, annotations_, img_id_
 if __name__ == '__main__':
@@ -107,7 +106,7 @@ if __name__ == '__main__':
     path_list = [args.train_ann, args.val_ann, args.train_dir, args.val_dir]
     for index, mode in enumerate(['train', 'val']):
         train = ReadJson(path_list[index])
-        data_loader = DataLoader(train, path_list[index+2], mode=mode)
+        data_loader = DataLoader(train, path_list[index+2], mode_=mode)
         save_dir = os.path.join(params['data_dir'], 'ignore_mask_{}'.format(mode))
         if not os.path.exists(save_dir):
@@ -12,37 +12,53 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-import time
 import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore.nn.loss.loss import _Loss
-from mindspore.train.callback import Callback
 from mindspore.ops import functional as F
 from mindspore.ops import composite as C
-from mindspore.communication.management import get_group_size
 from mindspore.context import ParallelMode
 from mindspore import context
+from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
+from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
+from src.config import params
 context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
 time_stamp_init = False
 time_stamp_first = 0
 grad_scale = C.MultitypeFuncGraph("grad_scale")
-_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
 reciprocal = P.Reciprocal()
-GRADIENT_CLIP_TYPE = 1
-GRADIENT_CLIP_VALUE = 1.0
-@grad_scale.register("Tensor", "Tensor")
-def tensor_grad_scale(scale, grad):
-    return grad * F.cast(reciprocal(scale), F.dtype(grad))
-@grad_scale.register("Tensor", "RowTensor")
-def tensor_grad_scale_row_tensor(scale, grad):
-    return RowTensor(grad.indices,
-                     grad.values * F.cast(reciprocal(scale), F.dtype(grad.values)),
-                     grad.dense_shape)
+GRADIENT_CLIP_TYPE = params['GRADIENT_CLIP_TYPE']
+GRADIENT_CLIP_VALUE = params['GRADIENT_CLIP_VALUE']
 clip_grad = C.MultitypeFuncGraph("clip_grad")
 @clip_grad.register("Number", "Number", "Tensor")
+def _clip_grad(clip_type, clip_value, grad):
+    """
+    Clip gradients.
+    Inputs:
+        clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'.
+        clip_value (float): Specifies how much to clip.
+        grad (tuple[Tensor]): Gradients.
+    Outputs:
+        tuple[Tensor]: clipped gradients.
+    """
+    if clip_type not in (0, 1):
+        return grad
+    dt = F.dtype(grad)
+    if clip_type == 0:
+        new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
+                                   F.cast(F.tuple_to_array((clip_value,)), dt))
+    else:
+        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
+    return new_grad
 class openpose_loss(_Loss):
     def __init__(self):
         super(openpose_loss, self).__init__()
@@ -99,109 +115,49 @@ class openpose_loss(_Loss):
         return total_loss, heatmaps_loss, pafs_loss
-class Depend_network(nn.Cell):
-    def __init__(self, network):
-        super(Depend_network, self).__init__()
+class BuildTrainNetwork(nn.Cell):
+    def __init__(self, network, criterion):
+        super(BuildTrainNetwork, self).__init__()
         self.network = network
+        self.criterion = criterion
-    def construct(self, *args):
-        loss, _, _ = self.network(*args)  # loss, heatmaps_loss, pafs_loss
+    def construct(self, input_data, gt_paf, gt_heatmap, mask):
+        logit_pafs, logit_heatmap = self.network(input_data)
+        loss, _, _ = self.criterion(logit_pafs, logit_heatmap, gt_paf, gt_heatmap, mask)
         return loss
-        #loss = self.criterion(logit_pafs, logit_heatmap, gt_paf, gt_heatmap, mask)
-        # return loss, heatmaps_loss, pafs_loss
-class TrainingWrapper(nn.Cell):
-    def __init__(self, network, optimizer, sens=1):
-        super(TrainingWrapper, self).__init__(auto_prefix=False)
+class TrainOneStepWithClipGradientCell(nn.Cell):
+    '''TrainOneStepWithClipGradientCell'''
+    def __init__(self, network, optimizer, sens=1.0):
+        super(TrainOneStepWithClipGradientCell, self).__init__(auto_prefix=False)
         self.network = network
-        self.depend_network = Depend_network(network)
-        # self.weights = ms.ParameterTuple(network.trainable_params())
+        self.network.set_grad()
+        self.network.add_flags(defer_inline=True)
         self.weights = optimizer.parameters
         self.optimizer = optimizer
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
+        self.hyper_map = C.HyperMap()
         self.sens = sens
         self.reducer_flag = False
         self.grad_reducer = None
-        self.print = P.Print()
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
+        parallel_mode = _get_parallel_mode()
+        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
             self.reducer_flag = True
         if self.reducer_flag:
-            mean = context.get_auto_parallel_context("gradients_mean")
-            #if mean.get_device_num_is_set():
-            # if mean:
-            #degree = context.get_auto_parallel_context("device_num")
-            # else:
-            degree = get_group_size()
-            self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
-    def construct(self, *args):
+            mean = _get_gradients_mean()
+            degree = _get_device_num()
+            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
+    def construct(self, *inputs):
         weights = self.weights
-        loss, heatmaps_loss, pafs_loss = self.network(*args)
+        loss = self.network(*inputs)
         sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
-        #grads = self.grad(self.network, weights)(*args, sens)
-        grads = self.grad(self.depend_network, weights)(*args, sens)
+        grads = self.grad(self.network, weights)(*inputs, sens)
+        grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
         if self.reducer_flag:
-            # apply grad reducer on grads
             grads = self.grad_reducer(grads)
-        #return F.depend(loss, self.optimizer(grads))
-        # for grad in grads:
-        #     self.print(grad)
-        loss = F.depend(loss, self.optimizer(grads))
-        return loss, heatmaps_loss, pafs_loss
-class BuildTrainNetwork(nn.Cell):
-    def __init__(self, network, criterion):
-        super(BuildTrainNetwork, self).__init__()
-        self.network = network
-        self.criterion = criterion
-    def construct(self, input_data, gt_paf, gt_heatmap, mask):
-        logit_pafs, logit_heatmap = self.network(input_data)
-        loss, _, _ = self.criterion(logit_pafs, logit_heatmap, gt_paf, gt_heatmap, mask)
-        return loss
-class LossCallBack(Callback):
-    """
-    Monitor the loss in training.
-    If the loss is NAN or INF terminating training.
-    Note:
-        If per_print_times is 0 do not print loss.
-    Args:
-        per_print_times (int): Print loss every times. Default: 1.
-    """
-    def __init__(self, per_print_times=1):
-        super(LossCallBack, self).__init__()
-        if not isinstance(per_print_times, int) or per_print_times < 0:
-            raise ValueError("print_step must be int and >= 0.")
-        self._per_print_times = per_print_times
-        self.count = 0
-        self.loss_sum = 0
-        global time_stamp_init, time_stamp_first
-        if not time_stamp_init:
-            time_stamp_first = time.time()
-            time_stamp_init = True
-    def step_end(self, run_context):
-        cb_params = run_context.original_args()
-        loss = cb_params.net_outputs.asnumpy()
-        self.count += 1
-        self.loss_sum += float(loss)
-        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
-        if self.count >= 1:
-            global time_stamp_first
-            time_stamp_current = time.time()
-            loss = self.loss_sum/self.count
-            loss_file = open("./loss.log", "a+")
-            loss_file.write("%lu epoch: %s step: %s ,loss: %.5f" %
-                            (time_stamp_current - time_stamp_first, cb_params.cur_epoch_num, cur_step_in_epoch,
-                             loss))
-            loss_file.write("\n")
-            loss_file.close()
-            self.count = 0
-            self.loss_sum = 0
+        return F.depend(loss, self.optimizer(grads))
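How the rebuilt pieces compose: `BuildTrainNetwork` fuses the forward pass and `openpose_loss` into one cell, and `TrainOneStepWithClipGradientCell` adds backprop with `sens` acting as a fixed loss scale, per-gradient clip-by-norm through the `clip_grad` hyper map, and gradient all-reduce under data parallelism. A hedged wiring sketch; `train.py` is not part of this diff, so the module path and optimizer choice below are illustrative assumptions:

```python
# Hedged sketch of assembling the training cells (module path `src.loss` and
# the Adam settings are assumptions; the real wiring lives in train.py).
import mindspore.nn as nn
from src.openposenet import OpenPoseNet
from src.loss import openpose_loss, BuildTrainNetwork, TrainOneStepWithClipGradientCell

net = OpenPoseNet(vgg_with_bn=False)
criterion = openpose_loss()
train_net = BuildTrainNetwork(net, criterion)

opt = nn.Adam(train_net.trainable_params(), learning_rate=1e-4)
# sens is fed into backprop as the gradient sensitivity (fixed loss scale);
# gradients are clipped by norm via the clip_grad hyper map before the update.
one_step = TrainOneStepWithClipGradientCell(train_net, opt, sens=1.0)
one_step.set_train()
```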
@@ -17,19 +17,18 @@ from mindspore.nn import Conv2d, ReLU
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.ops import operations as P
 from mindspore import context
 context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
-#selfCat = P.Concat(axis=1)
 time_stamp_init = False
 time_stamp_first = 0
 loadvgg = 1
 class OpenPoseNet(nn.Cell):
     insize = 368
-    def __init__(self, vggpath=''):
+    def __init__(self, vggpath='', vgg_with_bn=False):
         super(OpenPoseNet, self).__init__()
-        self.base = Base_model()
+        self.base = Base_model(vgg_with_bn=vgg_with_bn)
         self.stage_1 = Stage_1()
         self.stage_2 = Stage_x()
         self.stage_3 = Stage_x()
@@ -39,23 +38,15 @@ class OpenPoseNet(nn.Cell):
         self.shape = P.Shape()
         self.cat = P.Concat(axis=1)
         self.print = P.Print()
-        # for m in self.modules():
-        #     if isinstance(m, Conv2d):
-        #         init.constant_(m.bias, 0)
         if loadvgg and vggpath:
             param_dict = load_checkpoint(vggpath)
             param_dict_new = {}
             trans_name = 'base.vgg_base.'
             for key, values in param_dict.items():
-                #print('key:',key,self.shape(values))
                 if key.startswith('moments.'):
                     continue
                 elif key.startswith('network.'):
                     param_dict_new[trans_name+key[17:]] = values
-                # else:
-                #     param_dict_new[key] = values
-            #print(param_dict_new)
             load_param_into_net(self.base.vgg_base, param_dict_new)
     def construct(self, x):
@@ -205,20 +196,17 @@ class VGG_Base_MS(nn.Cell):
         return x
 class Base_model(nn.Cell):
-    def __init__(self):
+    def __init__(self, vgg_with_bn=False):
         super(Base_model, self).__init__()
-        #cfgs_zh = {'16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512]}
         cfgs_zh = {'19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512]}
-        #cfgs_zh = {'16': [64, 64,128, 128, 256, 256, 256, 512, 512, 512]}
-        self.vgg_base = Vgg(cfgs_zh['19'], batch_norm=False)
-        #self.vgg_base = VGG_Base()
-        #self.vgg_base = VGG_Base_MS()
+        self.vgg_base = Vgg(cfgs_zh['19'], batch_norm=vgg_with_bn)
         self.conv4_3_CPM = Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, pad_mode='same',
                                   has_bias=True)
         self.conv4_4_CPM = Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, pad_mode='same',
                                   has_bias=True)
         self.relu = ReLU()
     def construct(self, x):
         x = self.vgg_base(x)
         x = self.relu(self.conv4_3_CPM(x))
@@ -1,27 +1,11 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-import argparse
-import os
 import math
 import time
 import numpy as np
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.train.callback import LossMonitor
+from mindspore.train.callback import LossMonitor, Callback
 from mindspore.common.tensor import Tensor
-from src.config import params
+from mindspore.common import dtype as mstype
 class MyLossMonitor(LossMonitor):
     def __init__(self, per_print_times=1):
@@ -32,6 +16,7 @@ class MyLossMonitor(LossMonitor):
     def step_end(self, run_context):
         cb_params = run_context.original_args()
         loss = cb_params.net_outputs
         if isinstance(loss, (tuple, list)):
| @@ -47,63 +32,76 @@ class MyLossMonitor(LossMonitor): | |||||
| raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format( | raise ValueError("epoch: {} step: {}. Invalid loss, terminating training.".format( | ||||
| cb_params.cur_epoch_num, cur_step_in_epoch)) | cb_params.cur_epoch_num, cur_step_in_epoch)) | ||||
| if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: | if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: | ||||
| # print("epoch: %s step: %s, loss is %s, step time: %.3f s." % (cb_params.cur_epoch_num, cur_step_in_epoch, | |||||
| # loss, | |||||
| # (time.time() - self._start_time)), flush=True) | |||||
| self._loss_list.append(loss) | self._loss_list.append(loss) | ||||
| if cb_params.cur_step_num % 100 == 0: | if cb_params.cur_step_num % 100 == 0: | ||||
| print("epoch: %s, steps: [%s] mean loss is: %s"%(cb_params.cur_epoch_num, cur_step_in_epoch, | |||||
| np.array(self._loss_list).mean()), flush=True) | |||||
| print("epoch: %s, steps: [%s], mean loss is: %s"%(cb_params.cur_epoch_num, cur_step_in_epoch, | |||||
| np.array(self._loss_list).mean()), flush=True) | |||||
| self._loss_list = [] | self._loss_list = [] | ||||
| self._start_time = time.time() | self._start_time = time.time() | ||||
| class MyScaleSensCallback(Callback): | |||||
| '''Callback that updates the train network's loss scale (sense scale) at scheduled epochs.''' | |||||
| def __init__(self, loss_scale_list, epoch_list): | |||||
| super(MyScaleSensCallback, self).__init__() | |||||
| self.loss_scale_list = loss_scale_list | |||||
| self.epoch_list = epoch_list | |||||
| self.scaling_sens = loss_scale_list[0] | |||||
| def parse_args(): | |||||
| """Parse train arguments.""" | |||||
| parser = argparse.ArgumentParser('mindspore openpose training') | |||||
| # dataset related | |||||
| parser.add_argument('--train_dir', type=str, default='train2017', help='train data dir') | |||||
| parser.add_argument('--train_ann', type=str, default='person_keypoints_train2017.json', | |||||
| help='train annotations json') | |||||
| parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') | |||||
| args, _ = parser.parse_known_args() | |||||
| args.jsonpath_train = os.path.join(params['data_dir'], 'annotations/' + args.train_ann) | |||||
| args.imgpath_train = os.path.join(params['data_dir'], args.train_dir) | |||||
| args.maskpath_train = os.path.join(params['data_dir'], 'ignore_mask_train') | |||||
| return args | |||||
| def get_lr(lr, lr_gamma, steps_per_epoch, max_epoch_train, lr_steps, group_size): | |||||
| lr_stage = np.array([lr] * steps_per_epoch * max_epoch_train).astype('f') | |||||
| for step in lr_steps: | |||||
| step //= group_size | |||||
| lr_stage[step:] *= lr_gamma | |||||
| def epoch_end(self, run_context): | |||||
| cb_params = run_context.original_args() | |||||
| epoch = cb_params.cur_epoch_num | |||||
| for i, _ in enumerate(self.epoch_list): | |||||
| if epoch >= self.epoch_list[i]: | |||||
| self.scaling_sens = self.loss_scale_list[i+1] | |||||
| else: | |||||
| break | |||||
| scaling_sens_tensor = Tensor(self.scaling_sens, dtype=mstype.float32) | |||||
| cb_params.train_network.set_sense_scale(scaling_sens_tensor) | |||||
| print("Epoch: set train network scale sens to {}".format(self.scaling_sens)) | |||||
| def _linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr): | |||||
| lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps) | |||||
| learning_rate = float(init_lr) + lr_inc * current_step | |||||
| return learning_rate | |||||
| def _a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps): | |||||
| base = float(current_step - warmup_steps) / float(decay_steps) | |||||
| learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr | |||||
| return learning_rate | |||||
| def _dynamic_lr(base_lr, total_steps, warmup_steps, warmup_ratio=1 / 3): | |||||
| lr = [] | |||||
| for i in range(total_steps): | |||||
| if i < warmup_steps: | |||||
| lr.append(_linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * warmup_ratio)) | |||||
| else: | |||||
| lr.append(_a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps)) | |||||
| return lr | |||||
| def get_lr(lr, lr_gamma, steps_per_epoch, max_epoch_train, lr_steps, group_size, lr_type='default', warmup_epoch=5): | |||||
| if lr_type == 'default': | |||||
| lr_stage = np.array([lr] * steps_per_epoch * max_epoch_train).astype('f') | |||||
| for step in lr_steps: | |||||
| step //= group_size | |||||
| lr_stage[step:] *= lr_gamma | |||||
| elif lr_type == 'cosine': | |||||
| lr_stage = _dynamic_lr(lr, steps_per_epoch * max_epoch_train, warmup_epoch * steps_per_epoch, | |||||
| warmup_ratio=1 / 3) | |||||
| lr_stage = np.array(lr_stage).astype('f') | |||||
| else: | |||||
| raise ValueError("lr type {} is not support.".format(lr_type)) | |||||
| lr_base = lr_stage.copy() | lr_base = lr_stage.copy() | ||||
| lr_base = lr_base / 4 | lr_base = lr_base / 4 | ||||
| lr_vgg = lr_base.copy() | lr_vgg = lr_base.copy() | ||||
| vgg_freeze_step = 2000 | |||||
| vgg_freeze_step = 2000 // group_size | |||||
| lr_vgg[:vgg_freeze_step] = 0 | lr_vgg[:vgg_freeze_step] = 0 | ||||
| return lr_stage, lr_base, lr_vgg | |||||
| # zhang add | |||||
| def adjust_learning_rate(init_lr, lr_gamma, steps_per_epoch, max_epoch_train, stepvalues): | |||||
| lr_stage = np.array([init_lr] * steps_per_epoch * max_epoch_train).astype('f') | |||||
| for epoch in stepvalues: | |||||
| lr_stage[epoch * steps_per_epoch:] *= lr_gamma | |||||
| lr_base = lr_stage.copy() | |||||
| lr_base = lr_base / 4 | |||||
| lr_vgg = lr_base.copy() | |||||
| vgg_freeze_step = 2000 | |||||
| lr_vgg[:vgg_freeze_step] = 0 | |||||
| return lr_stage, lr_base, lr_vgg | return lr_stage, lr_base, lr_vgg | ||||
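For reference, the three schedules returned by the new `get_lr` differ only by scale and a freeze window: `lr_base` is a quarter of `lr_stage`, and `lr_vgg` additionally zeroes the first `2000 // group_size` steps so the VGG backbone stays frozen early in training. A quick standalone check of the `'default'` type, under assumed hyper-parameters (1e-4 base lr, 1000 steps per epoch, 3 epochs, one decay at step 2500):

```python
import numpy as np

def staged_lr(lr, lr_gamma, steps_per_epoch, max_epoch, lr_steps, group_size):
    # Mirrors get_lr(lr_type='default'): constant lr, multiplied by lr_gamma
    # from each step in lr_steps onward (steps are divided by the group size).
    lr_stage = np.array([lr] * steps_per_epoch * max_epoch, dtype=np.float32)
    for step in lr_steps:
        lr_stage[step // group_size:] *= lr_gamma
    lr_base = lr_stage / 4            # non-VGG base convs train at a quarter of the stage lr
    lr_vgg = lr_base.copy()
    lr_vgg[:2000 // group_size] = 0   # VGG backbone frozen for the first steps
    return lr_stage, lr_base, lr_vgg

stage, base, vgg = staged_lr(1e-4, 0.1, 1000, 3, [2500], 1)
print(stage[0], stage[-1])   # ~1e-4 before the decay step, ~1e-5 after
print(vgg[1999], vgg[2000])  # 0.0 while frozen, ~2.5e-5 once the backbone unfreezes
```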
| @@ -13,26 +13,38 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| import os | import os | ||||
| import argparse | |||||
| import mindspore | |||||
| from mindspore import context | from mindspore import context | ||||
| from mindspore.context import ParallelMode | from mindspore.context import ParallelMode | ||||
| from mindspore.communication.management import init, get_rank, get_group_size | from mindspore.communication.management import init, get_rank, get_group_size | ||||
| from mindspore.train import Model | from mindspore.train import Model | ||||
| from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor | from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor | ||||
| from mindspore.nn.optim import Adam, Momentum | |||||
| from mindspore.train.loss_scale_manager import FixedLossScaleManager | from mindspore.train.loss_scale_manager import FixedLossScaleManager | ||||
| from mindspore.nn.optim import Adam | |||||
| from src.dataset import create_dataset | from src.dataset import create_dataset | ||||
| from src.openposenet import OpenPoseNet | from src.openposenet import OpenPoseNet | ||||
| from src.loss import openpose_loss, BuildTrainNetwork | |||||
| from src.loss import openpose_loss, BuildTrainNetwork, TrainOneStepWithClipGradientCell | |||||
| from src.config import params | from src.config import params | ||||
| from src.utils import parse_args, get_lr, load_model, MyLossMonitor | |||||
| from src.utils import get_lr, load_model, MyLossMonitor | |||||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) | context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) | ||||
| parser = argparse.ArgumentParser('mindspore openpose training') | |||||
| parser.add_argument('--train_dir', type=str, default='train2017', help='train data dir') | |||||
| parser.add_argument('--train_ann', type=str, default='person_keypoints_train2017.json', | |||||
| help='train annotations json') | |||||
| parser.add_argument('--group_size', type=int, default=1, help='world size of distributed') | |||||
| args, _ = parser.parse_known_args() | |||||
| args.jsonpath_train = os.path.join(params['data_dir'], 'annotations/' + args.train_ann) | |||||
| args.imgpath_train = os.path.join(params['data_dir'], args.train_dir) | |||||
| args.maskpath_train = os.path.join(params['data_dir'], 'ignore_mask_train') | |||||
| def train(): | def train(): | ||||
| """Train function.""" | """Train function.""" | ||||
| args = parse_args() | |||||
| args.outputs_dir = params['save_model_path'] | args.outputs_dir = params['save_model_path'] | ||||
| @@ -46,11 +58,15 @@ def train(): | |||||
| args.outputs_dir = os.path.join(args.outputs_dir, "ckpt_0/") | args.outputs_dir = os.path.join(args.outputs_dir, "ckpt_0/") | ||||
| args.rank = 0 | args.rank = 0 | ||||
| # with out loss_scale | |||||
| if args.group_size > 1: | if args.group_size > 1: | ||||
| args.max_epoch = params["max_epoch_train_NP"] | |||||
| args.loss_scale = params['loss_scale'] / 2 | args.loss_scale = params['loss_scale'] / 2 | ||||
| args.lr_steps = list(map(int, params["lr_steps_NP"].split(','))) | args.lr_steps = list(map(int, params["lr_steps_NP"].split(','))) | ||||
| params['train_type'] = params['train_type_NP'] | |||||
| params['optimizer'] = params['optimizer_NP'] | |||||
| params['group_params'] = params['group_params_NP'] | |||||
| else: | else: | ||||
| args.max_epoch = params["max_epoch_train"] | |||||
| args.loss_scale = params['loss_scale'] | args.loss_scale = params['loss_scale'] | ||||
| args.lr_steps = list(map(int, params["lr_steps"].split(','))) | args.lr_steps = list(map(int, params["lr_steps"].split(','))) | ||||
| @@ -58,9 +74,7 @@ def train(): | |||||
| print('start create network') | print('start create network') | ||||
| criterion = openpose_loss() | criterion = openpose_loss() | ||||
| criterion.add_flags_recursive(fp32=True) | criterion.add_flags_recursive(fp32=True) | ||||
| network = OpenPoseNet(vggpath=params['vgg_path']) | |||||
| # network.add_flags_recursive(fp32=True) | |||||
| network = OpenPoseNet(vggpath=params['vgg_path'], vgg_with_bn=params['vgg_with_bn']) | |||||
| if params["load_pretrain"]: | if params["load_pretrain"]: | ||||
| print("load pretrain model:", params["pretrained_model_path"]) | print("load pretrain model:", params["pretrained_model_path"]) | ||||
| load_model(network, params["pretrained_model_path"]) | load_model(network, params["pretrained_model_path"]) | ||||
| @@ -72,7 +86,7 @@ def train(): | |||||
| print('start create dataset') | print('start create dataset') | ||||
| else: | else: | ||||
| print('Error: wrong data path') | print('Error: wrong data path') | ||||
| return 0 | |||||
| num_worker = 20 if args.group_size > 1 else 48 | num_worker = 20 if args.group_size > 1 else 48 | ||||
| de_dataset_train = create_dataset(args.jsonpath_train, args.imgpath_train, args.maskpath_train, | de_dataset_train = create_dataset(args.jsonpath_train, args.imgpath_train, args.maskpath_train, | ||||
| @@ -90,35 +104,63 @@ def train(): | |||||
| lr_stage, lr_base, lr_vgg = get_lr(params['lr'] * args.group_size, | lr_stage, lr_base, lr_vgg = get_lr(params['lr'] * args.group_size, | ||||
| params['lr_gamma'], | params['lr_gamma'], | ||||
| steps_per_epoch, | steps_per_epoch, | ||||
| params["max_epoch_train"], | |||||
| args.max_epoch, | |||||
| args.lr_steps, | args.lr_steps, | ||||
| args.group_size) | |||||
| vgg19_base_params = list(filter(lambda x: 'base.vgg_base' in x.name, train_net.trainable_params())) | |||||
| base_params = list(filter(lambda x: 'base.conv' in x.name, train_net.trainable_params())) | |||||
| stages_params = list(filter(lambda x: 'base' not in x.name, train_net.trainable_params())) | |||||
| group_params = [{'params': vgg19_base_params, 'lr': lr_vgg}, | |||||
| {'params': base_params, 'lr': lr_base}, | |||||
| {'params': stages_params, 'lr': lr_stage}] | |||||
| opt = Adam(group_params, loss_scale=args.loss_scale) | |||||
| train_net.set_train(True) | |||||
| loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False) | |||||
| model = Model(train_net, optimizer=opt, loss_scale_manager=loss_scale_manager) | |||||
| params['ckpt_interval'] = max(steps_per_epoch, params['ckpt_interval']) | |||||
| args.group_size, | |||||
| lr_type=params['lr_type'], | |||||
| warmup_epoch=params['warmup_epoch']) | |||||
| # optimizer | |||||
| if params['group_params']: | |||||
| vgg19_base_params = list(filter(lambda x: 'base.vgg_base' in x.name, train_net.trainable_params())) | |||||
| base_params = list(filter(lambda x: 'base.conv' in x.name, train_net.trainable_params())) | |||||
| stages_params = list(filter(lambda x: 'base' not in x.name, train_net.trainable_params())) | |||||
| group_params = [{'params': vgg19_base_params, 'lr': lr_vgg}, | |||||
| {'params': base_params, 'lr': lr_base}, | |||||
| {'params': stages_params, 'lr': lr_stage}] | |||||
| if params['optimizer'] == "Momentum": | |||||
| opt = Momentum(group_params, learning_rate=lr_stage, momentum=0.9) | |||||
| elif params['optimizer'] == "Adam": | |||||
| opt = Adam(group_params) | |||||
| else: | |||||
| raise ValueError("optimizer not support.") | |||||
| else: | |||||
| if params['optimizer'] == "Momentum": | |||||
| opt = Momentum(train_net.trainable_params(), learning_rate=lr_stage, momentum=0.9) | |||||
| elif params['optimizer'] == "Adam": | |||||
| opt = Adam(train_net.trainable_params(), learning_rate=lr_stage) | |||||
| else: | |||||
| raise ValueError("optimizer not support.") | |||||
| # callback | |||||
| config_ck = CheckpointConfig(save_checkpoint_steps=params['ckpt_interval'], | config_ck = CheckpointConfig(save_checkpoint_steps=params['ckpt_interval'], | ||||
| keep_checkpoint_max=params["keep_checkpoint_max"]) | keep_checkpoint_max=params["keep_checkpoint_max"]) | ||||
| ckpoint_cb = ModelCheckpoint(prefix='{}'.format(args.rank), directory=args.outputs_dir, config=config_ck) | ckpoint_cb = ModelCheckpoint(prefix='{}'.format(args.rank), directory=args.outputs_dir, config=config_ck) | ||||
| time_cb = TimeMonitor(data_size=de_dataset_train.get_dataset_size()) | time_cb = TimeMonitor(data_size=de_dataset_train.get_dataset_size()) | ||||
| callback_list = [MyLossMonitor(), time_cb, ckpoint_cb] | |||||
| if args.rank == 0: | |||||
| callback_list = [MyLossMonitor(), time_cb, ckpoint_cb] | |||||
| else: | |||||
| callback_list = [MyLossMonitor(), time_cb] | |||||
| # train | |||||
| if params['train_type'] == 'clip_grad': | |||||
| train_net = TrainOneStepWithClipGradientCell(train_net, opt, sens=args.loss_scale) | |||||
| train_net.set_train() | |||||
| model = Model(train_net) | |||||
| elif params['train_type'] == 'fix_loss_scale': | |||||
| loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False) | |||||
| train_net.set_train() | |||||
| model = Model(train_net, optimizer=opt, loss_scale_manager=loss_scale_manager) | |||||
| else: | |||||
| raise ValueError("Type {} is not support.".format(params['train_type'])) | |||||
| print("============== Starting Training ==============") | print("============== Starting Training ==============") | ||||
| model.train(params["max_epoch_train"], de_dataset_train, callbacks=callback_list, | |||||
| model.train(args.max_epoch, de_dataset_train, callbacks=callback_list, | |||||
| dataset_sink_mode=False) | dataset_sink_mode=False) | ||||
| return 0 | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||
| # mindspore.common.seed.set_seed(1) | |||||
| mindspore.common.seed.set_seed(1) | |||||
| train() | train() | ||||
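Note that the diff adds `MyScaleSensCallback` to `src/utils.py`, but `train.py` as shown does not wire it in. If you want the fixed loss scale to step down during training, a hypothetical hookup (the scale and epoch values below are illustrative assumptions, not from this diff) could look like:

```python
from src.utils import MyScaleSensCallback

# loss_scale_list must be one entry longer than epoch_list:
# scale 16384 until epoch 40, 8192 until epoch 60, then 4096.
scale_cb = MyScaleSensCallback(loss_scale_list=[16384, 8192, 4096], epoch_list=[40, 60])
callback_list.append(scale_cb)  # epoch_end() then calls train_network.set_sense_scale(...)
```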