| @@ -26,7 +26,6 @@ from eval import load_weights | |||||
| parser = argparse.ArgumentParser(description='transformer export') | parser = argparse.ArgumentParser(description='transformer export') | ||||
| parser.add_argument("--device_id", type=int, default=0, help="Device id") | parser.add_argument("--device_id", type=int, default=0, help="Device id") | ||||
| parser.add_argument("--batch_size", type=int, default=1, help="batch size") | |||||
| parser.add_argument("--file_name", type=str, default="transformer", help="output file name.") | parser.add_argument("--file_name", type=str, default="transformer", help="output file name.") | ||||
| parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') | parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') | ||||
| parser.add_argument("--device_target", type=str, default="Ascend", | parser.add_argument("--device_target", type=str, default="Ascend", | ||||
| @@ -43,7 +42,7 @@ if __name__ == '__main__': | |||||
| parameter_dict = load_weights(cfg.model_file) | parameter_dict = load_weights(cfg.model_file) | ||||
| load_param_into_net(tfm_model, parameter_dict) | load_param_into_net(tfm_model, parameter_dict) | ||||
| source_ids = Tensor(np.ones((args.batch_size, transformer_net_cfg.seq_length)).astype(np.int32)) | |||||
| source_mask = Tensor(np.ones((args.batch_size, transformer_net_cfg.seq_length)).astype(np.int32)) | |||||
| source_ids = Tensor(np.ones((transformer_net_cfg.batch_size, transformer_net_cfg.seq_length)).astype(np.int32)) | |||||
| source_mask = Tensor(np.ones((transformer_net_cfg.batch_size, transformer_net_cfg.seq_length)).astype(np.int32)) | |||||
| export(tfm_model, source_ids, source_mask, file_name=args.file_name, file_format=args.file_format) | export(tfm_model, source_ids, source_mask, file_name=args.file_name, file_format=args.file_format) | ||||