|
|
|
@@ -34,8 +34,8 @@ set_seed(1) |
|
|
|
parser = argparse.ArgumentParser(description="crnn training") |
|
|
|
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.") |
|
|
|
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None') |
|
|
|
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'], |
|
|
|
help='Running platform, choose from Ascend, GPU, and default is Ascend.') |
|
|
|
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend'], |
|
|
|
help='Running platform, only support Ascend now. Default is Ascend.') |
|
|
|
parser.add_argument('--model', type=str, default='lowercase', help="Model type, default is lowercase") |
|
|
|
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k']) |
|
|
|
parser.set_defaults(run_distribute=False) |
|
|
|
@@ -92,7 +92,7 @@ if __name__ == '__main__': |
|
|
|
model = Model(net) |
|
|
|
# define callbacks |
|
|
|
callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)] |
|
|
|
if config.save_checkpoint: |
|
|
|
if config.save_checkpoint and rank == 0: |
|
|
|
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps, |
|
|
|
keep_checkpoint_max=config.keep_checkpoint_max) |
|
|
|
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/') |
|
|
|
|