|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160 |
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """
- Test centerface example
- """
- import os
- import time
- import argparse
- import datetime
- import scipy.io as sio
-
- from mindspore import context
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
-
- from src.utils import get_logger
- from src.var_init import default_recurisive_init
- from src.centerface import CenterfaceMobilev2, CenterFaceWithNms
- from src.config import ConfigCenterface
-
- from dependency.centernet.src.lib.detectors.base_detector import CenterFaceDetector
- from dependency.evaluate.eval import evaluation
-
# Target device comes from the DEVICE_ID environment variable; default to
# device 0 instead of failing with TypeError (int(None)) when it is unset.
dev_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=False,
                    device_target="Ascend", save_graphs=False, device_id=dev_id)

parser = argparse.ArgumentParser('mindspore centerface testing')
parser.add_argument('--data_dir', type=str, default='', help='test data dir')
parser.add_argument('--test_model', type=str, default='', help='test model dir')
parser.add_argument('--ground_truth_mat', type=str, default='', help='ground_truth, mat type')
parser.add_argument('--save_dir', type=str, default='', help='save_path for evaluate')
parser.add_argument('--ground_truth_path', type=str, default='', help='ground_truth path, contain all mat file')
parser.add_argument('--eval', type=int, default=0, help='if do eval after test')
parser.add_argument('--eval_script_path', type=str, default='', help='evaluate script path')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
# NOTE: in this script ckpt_path is only used as the base directory for logs.
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
parser.add_argument('--ckpt_name', type=str, default="", help='input model name')
parser.add_argument('--device_num', type=int, default=1, help='device num for testing')
parser.add_argument('--steps_per_epoch', type=int, default=198, help='steps for each epoch')
parser.add_argument('--start', type=int, default=0, help='start loop number, used to calculate first epoch number')
parser.add_argument('--end', type=int, default=18, help='end loop number, used to calculate last epoch number')

# parse_known_args ignores unrecognized command-line flags instead of erroring.
args, _ = parser.parse_known_args()
-
def filter_checkpoint_params(param_dict,
                             skip_prefixes=('moments.', 'moment1.', 'moment2.'),
                             strip_prefix='centerface_network.'):
    """Prepare a loaded checkpoint dict for ``load_param_into_net``.

    Entries whose key starts with any of ``skip_prefixes`` (presumably
    optimizer moment state) are dropped. A leading ``strip_prefix`` is removed
    from the remaining keys so they line up with the parameter names of the
    bare network the result is loaded into.

    Args:
        param_dict: mapping of parameter name -> value, e.g. as returned by
            ``load_checkpoint``.
        skip_prefixes: iterable of key prefixes to discard.
        strip_prefix: key prefix to remove from kept entries.

    Returns:
        A new dict; ``param_dict`` itself is not modified.
    """
    skip = tuple(skip_prefixes)  # str.startswith needs a tuple, not a list
    filtered = {}
    for key, value in param_dict.items():
        if key.startswith(skip):
            continue
        if key.startswith(strip_prefix):
            key = key[len(strip_prefix):]
        filtered[key] = value
    return filtered


def write_wider_result(txt_path, image_name, boxes):
    """Write one image's detections to ``txt_path`` as plain text.

    Layout: the image name on the first line, the number of boxes on the
    second, then one ``x y w h score`` line per box (presumably the WIDER FACE
    result format read by the evaluation step). Each box is indexed as
    ``(x1, y1, x2, y2, score)``; width and height are ``x2 - x1 + 1`` and
    ``y2 - y1 + 1``.

    Args:
        txt_path: output file path (overwritten).
        image_name: text written on the first line.
        boxes: sized iterable of indexable rows with at least five entries.
    """
    # The count line must describe the rows that follow it.
    lines = ['{:s}\n'.format(image_name), '{:d}\n'.format(len(boxes))]
    for box in boxes:
        x1, y1, x2, y2, score = box[0], box[1], box[2], box[3], box[4]
        lines.append('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.format(
            x1, y1, (x2 - x1 + 1), (y2 - y1 + 1), score))
    # `with` guarantees the handle is closed even if writing fails.
    with open(txt_path, 'w') as out_file:
        out_file.writelines(lines)


if __name__ == "__main__":
    # logger: output goes to <ckpt_path>/<timestamp>/
    args.outputs_dir = os.path.join(args.ckpt_path,
                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
    args.logger = get_logger(args.outputs_dir, args.rank)
    args.logger.save_args(args)

    if args.ckpt_name != "":
        # An explicitly named checkpoint is tested exactly once.
        args.start = 0
        args.end = 1

    # The ground-truth .mat is the same for every checkpoint: load it on first
    # use and reuse it in later iterations.
    ground_truth_mat = None

    for loop in range(args.start, args.end, 1):
        if args.ckpt_name == "":
            # Rank r of device_num tests checkpoints r+1, r+1+device_num, ...
            # named "0-<num>_<num * steps_per_epoch>.ckpt".
            ckpt_num = loop * args.device_num + args.rank + 1
            ckpt_name = "0-" + str(ckpt_num) + "_" + str(args.steps_per_epoch * ckpt_num) + ".ckpt"
        else:
            ckpt_name = args.ckpt_name

        # --test_model is a directory; join works with or without a trailing '/'.
        test_model = os.path.join(args.test_model, ckpt_name)
        if not os.path.isfile(test_model):
            args.logger.info('{} not exists or not a pre-trained file'.format(test_model))
            continue

        # Only build the network once we know there is a checkpoint to load.
        network = CenterfaceMobilev2()
        default_recurisive_init(network)
        load_param_into_net(network, filter_checkpoint_params(load_checkpoint(test_model)))
        args.logger.info('load_model {} success'.format(test_model))

        train_network_type_nms = 1  # default with nms
        if train_network_type_nms:
            network = CenterFaceWithNms(network)
            args.logger.info('train network type with nms')
        network.set_train(False)
        args.logger.info('finish get network')

        config = ConfigCenterface()

        if ground_truth_mat is None:
            ground_truth_mat = sio.loadmat(args.ground_truth_mat)
        event_list = ground_truth_mat['event_list']
        file_list = ground_truth_mat['file_list']

        if args.ckpt_name == "":
            save_path = args.save_dir + str(ckpt_num) + '/'
        else:
            save_path = args.save_dir + '/'
        detector = CenterFaceDetector(config, network)

        # test network -----------
        start = time.time()
        for index, event in enumerate(event_list):
            file_list_item = file_list[index][0]
            im_dir = event[0][0]
            event_dir = os.path.join(save_path, im_dir)
            if not os.path.exists(event_dir):
                # exist_ok: another rank may create the same dir concurrently.
                os.makedirs(event_dir, exist_ok=True)
                args.logger.info('save_path + im_dir={}'.format(event_dir))
            for num, entry in enumerate(file_list_item):
                im_name = entry[0][0]
                zip_name = '%s/%s.jpg' % (im_dir, im_name)
                img_path = os.path.join(args.data_dir, zip_name)
                args.logger.info('img_path={}'.format(img_path))

                dets = detector.run(img_path)['results']
                # Only dets[1] is written, so its length is the box count.
                write_wider_result(os.path.join(event_dir, im_name + '.txt'), zip_name, dets[1])
                args.logger.info('event:{}, num:{}'.format(index + 1, num + 1))

                end = time.time()
                args.logger.info("============num {} time {}".format(num, (end-start)*1000))
                start = end

        if args.eval:
            args.logger.info('==========start eval===============')
            args.logger.info("test output path = {}".format(save_path))
            if os.path.isdir(save_path):
                evaluation(save_path, args.ground_truth_path)
            else:
                args.logger.info('no test output path')
            args.logger.info('==========end eval===============')

        if args.ckpt_name != "":
            break

    args.logger.info('==========end testing===============')
|