|
|
|
@@ -90,13 +90,13 @@ def main(): |
|
|
|
cfg.logger.info('start create dataloader')
|
|
|
|
de_dataset, steps_per_epoch, class_num = get_de_dataset(cfg)
|
|
|
|
cfg.steps_per_epoch = steps_per_epoch
|
|
|
|
cfg.logger.info('step per epoch: %d' % cfg.steps_per_epoch)
|
|
|
|
cfg.logger.info('step per epoch: {}'.format(cfg.steps_per_epoch))
|
|
|
|
de_dataloader = de_dataset.create_tuple_iterator()
|
|
|
|
cfg.logger.info('class num original: %d' % class_num)
|
|
|
|
cfg.logger.info('class num original: {}'.format(class_num))
|
|
|
|
if class_num % 16 != 0:
|
|
|
|
class_num = (class_num // 16 + 1) * 16
|
|
|
|
cfg.class_num = class_num
|
|
|
|
cfg.logger.info('change the class num to: %d' % cfg.class_num)
|
|
|
|
cfg.logger.info('change the class num to: {}'.format(cfg.class_num))
|
|
|
|
cfg.logger.info('end create dataloader')
|
|
|
|
|
|
|
|
# backbone and loss
|
|
|
|
@@ -119,7 +119,7 @@ def main(): |
|
|
|
else:
|
|
|
|
param_dict_new[key] = values
|
|
|
|
load_param_into_net(network, param_dict_new)
|
|
|
|
cfg.logger.info('load model %s success' % cfg.pretrained)
|
|
|
|
cfg.logger.info('load model %s success', cfg.pretrained)
|
|
|
|
|
|
|
|
# mixed precision training
|
|
|
|
network.add_flags_recursive(fp16=True)
|
|
|
|
|