|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- import argparse
- import os
- import time
- import numpy as np
-
-
- from mindspore import Tensor, float32, context
- from mindspore.common import set_seed
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
-
- from src.config import config
- from src.dataset import flip_pairs, keypoint_dataset
- from src.evaluate.coco_eval import evaluate
- from src.model import get_pose_net
- from src.utils.transform import flip_back
- from src.predict import get_final_preds
-
-
def parse_args():
    """Parse the command-line options of this script from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``train_url``, ``data_url``,
        ``output_url``, ``workers``, ``model_file``, ``use_detect_bbox``,
        ``flip_test``, ``post_process``, ``shift_heatmap`` and
        ``coco_bbox_file``.
    """
    parser = argparse.ArgumentParser(description='Train keypoints network')
    add = parser.add_argument

    # Generic url options; both default to the empty string.
    add("--train_url", type=str, default="", help="")
    add("--data_url", type=str, default="", help="data")

    # Output location.
    add('--output-url', type=str, help='output dir')

    # Data loading and checkpoint selection.
    add('--workers', type=int, default=8, help='num of dataloader workers')
    add('--model-file', type=str, help='model state file')

    # Boolean switches. NOTE: --flip-test is a store_true flag whose default
    # is already True, so the parsed value is True whether or not the flag
    # is given on the command line.
    add('--use-detect-bbox', action='store_true', help='use detect bbox')
    add('--flip-test', action='store_true', default=True, help='use flip test')
    add('--post-process', action='store_true', help='use post process')
    add('--shift-heatmap', action='store_true', help='shift heatmap')

    add('--coco-bbox-file', type=str, help='coco detection bbox file')

    return parser.parse_args()
-
-
def reset_config(cfg, args):
    """Overlay the command-line options onto ``cfg.TEST`` in place.

    Only options with a truthy value overwrite the configuration; anything
    left unset (``False``, ``None`` or an empty string) keeps the value that
    ``cfg`` already holds.

    Args:
        cfg: configuration object with a mutable ``TEST`` section.
        args: parsed arguments providing ``use_detect_bbox``, ``flip_test``,
            ``post_process``, ``shift_heatmap``, ``model_file`` and
            ``coco_bbox_file``.
    """
    test_cfg = cfg.TEST

    # Using detected boxes switches ground-truth boxes off.
    if args.use_detect_bbox:
        test_cfg.USE_GT_BBOX = not args.use_detect_bbox

    if args.flip_test:
        test_cfg.FLIP_TEST = args.flip_test
        print('use flip test:', test_cfg.FLIP_TEST)

    # The remaining options are copied through unchanged when they are set.
    passthrough = (
        ('post_process', 'POST_PROCESS'),
        ('shift_heatmap', 'SHIFT_HEATMAP'),
        ('model_file', 'MODEL_FILE'),
        ('coco_bbox_file', 'COCO_BBOX_FILE'),
    )
    for arg_name, cfg_key in passthrough:
        value = getattr(args, arg_name)
        if value:
            setattr(test_cfg, cfg_key, value)
-
-
def validate(cfg, val_dataset, model, output_dir):
    """Run ``model`` over ``val_dataset`` and report the evaluation score.

    Args:
        cfg: configuration object. Reads ``cfg.TEST.BATCH_SIZE``,
            ``cfg.TEST.FLIP_TEST``, ``cfg.TEST.SHIFT_HEATMAP`` and
            ``cfg.MODEL.NUM_JOINTS``; it is also forwarded to
            ``get_final_preds`` and ``evaluate``.
        val_dataset: dataset providing ``get_dataset_size()`` and
            ``create_dict_iterator()``; every item must carry the keys
            ``'image'``, ``'center'``, ``'scale'``, ``'score'`` and ``'id'``.
        model: network called on a float32 ``Tensor``; it is put into
            inference mode with ``set_train(False)``.
        output_dir: passed through to ``evaluate`` unchanged.

    Returns:
        The second element of the ``evaluate`` result (printed as ``AP``).
    """
    # Inference mode.
    model.set_train(False)

    # Result buffers sized for the maximum number of samples; only the first
    # ``seen`` rows are handed to the evaluator at the end.
    capacity = val_dataset.get_dataset_size() * cfg.TEST.BATCH_SIZE
    all_preds = np.zeros((capacity, cfg.MODEL.NUM_JOINTS, 3), dtype=np.float32)
    all_boxes = np.zeros((capacity, 2))
    image_id = []
    seen = 0

    tick = time.time()
    for batch in val_dataset.create_dict_iterator():
        images = batch['image'].asnumpy()
        heatmaps = model(Tensor(images, float32)).asnumpy()

        if cfg.TEST.FLIP_TEST:
            # Second forward pass on the input reversed along its last axis.
            mirrored = model(Tensor(images[:, :, :, ::-1], float32))
            mirrored = flip_back(mirrored.asnumpy(), flip_pairs)

            # feature is not aligned, shift flipped heatmap for higher accuracy
            if cfg.TEST.SHIFT_HEATMAP:
                mirrored[:, :, :, 1:] = mirrored[:, :, :, :-1].copy()

            # Average the plain and the flipped prediction.
            heatmaps = (heatmaps + mirrored) * 0.5

        # Per-sample meta data.
        centers = batch['center'].asnumpy()
        scales = batch['scale'].asnumpy()
        scores = batch['score'].asnumpy()
        ids = list(batch['id'].asnumpy())

        # Turn heatmaps into coordinates plus peak values.
        preds, maxvals = get_final_preds(cfg, heatmaps.copy(), centers, scales)
        upto = seen + preds.shape[0]
        rows = slice(seen, upto)
        all_preds[rows, :, 0:2] = preds[:, :, 0:2]
        all_preds[rows, :, 2:3] = maxvals
        # NOTE(review): the original author flagged the all_boxes columns as
        # needing a double check. Column 0 is the product of scale * 200 over
        # axis 1, column 1 is the score.
        all_boxes[rows, 0] = np.prod(scales * 200, 1)
        all_boxes[rows, 1] = scores
        image_id.extend(ids)
        seen = upto

        # Progress report; the timer restarts after every report.
        if seen % 1024 == 0:
            print('{} samples validated in {} seconds'.format(seen, time.time() - tick))
            tick = time.time()

    print(all_preds[:seen].shape, all_boxes[:seen].shape, len(image_id))
    _, perf_indicator = evaluate(
        cfg, all_preds[:seen], output_dir, all_boxes[:seen], image_id)
    print("AP:", perf_indicator)
    return perf_indicator
-
-
def main():
    """Evaluate the checkpoint named in the configuration on the validation set.

    Steps: seed the framework, configure the MindSpore context for Ascend,
    merge the command-line options into ``config``, build the network and
    load ``config.TEST.MODEL_FILE`` into it, create the validation dataset
    and run ``validate``. Evaluation output goes to the checkpoint path with
    its extension removed.
    """
    # init seed
    set_seed(1)

    # set context
    # DEVICE_ID may be absent from the environment; fall back to device 0
    # rather than failing with a TypeError on int(None).
    device_id = int(os.getenv('DEVICE_ID', '0'))
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend", save_graphs=False, device_id=device_id)

    args = parse_args()
    # update config
    reset_config(config, args)

    # init model
    model = get_pose_net(config, is_train=False)

    # load parameters
    ckpt_name = config.TEST.MODEL_FILE
    print('loading model ckpt from {}'.format(ckpt_name))
    load_param_into_net(model, load_checkpoint(ckpt_name))

    # Data loading code
    valid_dataset, _ = keypoint_dataset(
        config,
        bbox_file=config.TEST.COCO_BBOX_FILE,
        train_mode=False,
        num_parallel_workers=args.workers,
    )

    # Strip only the file extension. Splitting on the first '.' would yield
    # an empty string for paths such as './ckpt/model.ckpt' and would cut the
    # path short at any dot inside a directory name.
    output_dir = os.path.splitext(ckpt_name)[0]

    # evaluate on validation set
    validate(config, valid_dataset, model, output_dir)


if __name__ == '__main__':
    main()
|