Browse Source

modify transformer export

tags/v1.2.0-rc1
changzherui 4 years ago
parent
commit
a0ff6613c3
1 changed files with 2 additions and 3 deletions
  1. +2
    -3
      model_zoo/official/nlp/transformer/export.py

+ 2
- 3
model_zoo/official/nlp/transformer/export.py View File

@@ -26,7 +26,6 @@ from eval import load_weights


parser = argparse.ArgumentParser(description='transformer export') parser = argparse.ArgumentParser(description='transformer export')
parser.add_argument("--device_id", type=int, default=0, help="Device id") parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--file_name", type=str, default="transformer", help="output file name.") parser.add_argument("--file_name", type=str, default="transformer", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
parser.add_argument("--device_target", type=str, default="Ascend", parser.add_argument("--device_target", type=str, default="Ascend",
@@ -43,7 +42,7 @@ if __name__ == '__main__':
parameter_dict = load_weights(cfg.model_file) parameter_dict = load_weights(cfg.model_file)
load_param_into_net(tfm_model, parameter_dict) load_param_into_net(tfm_model, parameter_dict)


source_ids = Tensor(np.ones((args.batch_size, transformer_net_cfg.seq_length)).astype(np.int32))
source_mask = Tensor(np.ones((args.batch_size, transformer_net_cfg.seq_length)).astype(np.int32))
source_ids = Tensor(np.ones((transformer_net_cfg.batch_size, transformer_net_cfg.seq_length)).astype(np.int32))
source_mask = Tensor(np.ones((transformer_net_cfg.batch_size, transformer_net_cfg.seq_length)).astype(np.int32))


export(tfm_model, source_ids, source_mask, file_name=args.file_name, file_format=args.file_format) export(tfm_model, source_ids, source_mask, file_name=args.file_name, file_format=args.file_format)

Loading…
Cancel
Save