|
|
|
@@ -22,6 +22,7 @@ from src.config import config |
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='checkpoint export') |
|
|
|
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') |
|
|
|
parser.add_argument('--output_file', type=str, default='', help='resnet output air name.') |
|
|
|
args_opt = parser.parse_args() |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
@@ -41,4 +42,4 @@ if __name__ == '__main__': |
|
|
|
load_param_into_net(net, param_dict) |
|
|
|
|
|
|
|
inputs = np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]).astype(np.float32) |
|
|
|
export(net, Tensor(inputs), file_name='resnet-42_5004', file_format='AIR') |
|
|
|
export(net, Tensor(inputs), file_name=args_opt.output_file, file_format='AIR') |