diff --git a/model_zoo/official/cv/resnet_thor/export.py b/model_zoo/official/cv/resnet_thor/export.py index 9b79d97aed..f90cfa0467 100644 --- a/model_zoo/official/cv/resnet_thor/export.py +++ b/model_zoo/official/cv/resnet_thor/export.py @@ -22,6 +22,7 @@ from src.config import config parser = argparse.ArgumentParser(description='checkpoint export') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') +parser.add_argument('--output_file', type=str, default='', help='resnet output air name.') args_opt = parser.parse_args() if __name__ == '__main__': @@ -41,4 +42,4 @@ if __name__ == '__main__': load_param_into_net(net, param_dict) inputs = np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]).astype(np.float32) - export(net, Tensor(inputs), file_name='resnet-42_5004', file_format='AIR') + export(net, Tensor(inputs), file_name=args_opt.output_file, file_format='AIR')