|
|
|
@@ -26,7 +26,6 @@ from eval import load_weights |
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='transformer export') |
|
|
|
parser.add_argument("--device_id", type=int, default=0, help="Device id") |
|
|
|
parser.add_argument("--batch_size", type=int, default=1, help="batch size") |
|
|
|
parser.add_argument("--file_name", type=str, default="transformer", help="output file name.") |
|
|
|
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') |
|
|
|
parser.add_argument("--device_target", type=str, default="Ascend", |
|
|
|
@@ -43,7 +42,7 @@ if __name__ == '__main__': |
|
|
|
parameter_dict = load_weights(cfg.model_file) |
|
|
|
load_param_into_net(tfm_model, parameter_dict) |
|
|
|
|
|
|
|
source_ids = Tensor(np.ones((args.batch_size, transformer_net_cfg.seq_length)).astype(np.int32)) |
|
|
|
source_mask = Tensor(np.ones((args.batch_size, transformer_net_cfg.seq_length)).astype(np.int32)) |
|
|
|
source_ids = Tensor(np.ones((transformer_net_cfg.batch_size, transformer_net_cfg.seq_length)).astype(np.int32)) |
|
|
|
source_mask = Tensor(np.ones((transformer_net_cfg.batch_size, transformer_net_cfg.seq_length)).astype(np.int32)) |
|
|
|
|
|
|
|
export(tfm_model, source_ids, source_mask, file_name=args.file_name, file_format=args.file_format) |