diff --git a/model_zoo/official/nlp/transformer/export.py b/model_zoo/official/nlp/transformer/export.py index 40bd3d5094..6c0e30f1dc 100644 --- a/model_zoo/official/nlp/transformer/export.py +++ b/model_zoo/official/nlp/transformer/export.py @@ -26,7 +26,6 @@ from eval import load_weights parser = argparse.ArgumentParser(description='transformer export') parser.add_argument("--device_id", type=int, default=0, help="Device id") -parser.add_argument("--batch_size", type=int, default=1, help="batch size") parser.add_argument("--file_name", type=str, default="transformer", help="output file name.") parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') parser.add_argument("--device_target", type=str, default="Ascend", @@ -43,7 +42,7 @@ if __name__ == '__main__': parameter_dict = load_weights(cfg.model_file) load_param_into_net(tfm_model, parameter_dict) - source_ids = Tensor(np.ones((args.batch_size, transformer_net_cfg.seq_length)).astype(np.int32)) - source_mask = Tensor(np.ones((args.batch_size, transformer_net_cfg.seq_length)).astype(np.int32)) + source_ids = Tensor(np.ones((transformer_net_cfg.batch_size, transformer_net_cfg.seq_length)).astype(np.int32)) + source_mask = Tensor(np.ones((transformer_net_cfg.batch_size, transformer_net_cfg.seq_length)).astype(np.int32)) export(tfm_model, source_ids, source_mask, file_name=args.file_name, file_format=args.file_format)