
gigaword.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """Generate Gigaword dataset."""
  16. import os
  17. import argparse
  18. from src.dataset import BiLingualDataLoader
  19. from src.language_model import NoiseChannelLanguageModel
  20. from src.utils import Dictionary
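
# Command-line arguments for building the Gigaword fine-tuning dataset.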
parser = argparse.ArgumentParser(description='Create Gigaword fine-tuning dataset.')
parser.add_argument("--train_src", type=str, default="", required=False,
                    help="Path to the training source file.")
parser.add_argument("--train_ref", type=str, default="", required=False,
                    help="Path to the training reference file.")
parser.add_argument("--test_src", type=str, default="", required=False,
                    help="Path to the test source file.")
parser.add_argument("--test_ref", type=str, default="", required=False,
                    help="Path to the test reference file.")
parser.add_argument("--noise_prob", type=float, default=0., required=False,
                    help="Probability of adding noise to the source sentences.")
parser.add_argument("--existed_vocab", type=str, default="", required=False,
                    help="Path to an existing vocabulary file.")
parser.add_argument("--max_len", type=int, default=64, required=False,
                    help="Maximum sentence length.")
parser.add_argument("--output_folder", type=str, default="", required=True,
                    help="Output path for the generated dataset.")
parser.add_argument("--format", type=str, default="tfrecord", required=False,
                    help="Dataset format: 'tfrecord' or 'mindrecord'.")
if __name__ == '__main__':
    args, _ = parser.parse_known_args()
    vocab = Dictionary.load_from_persisted_dict(args.existed_vocab)
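
    # Training split: noise may be injected into the source side
    # according to --noise_prob.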
    if args.train_src and args.train_ref:
        train = BiLingualDataLoader(
            src_filepath=args.train_src,
            tgt_filepath=args.train_ref,
            src_dict=vocab, tgt_dict=vocab,
            src_lang="en", tgt_lang="en",
            language_model=NoiseChannelLanguageModel(add_noise_prob=args.noise_prob),
            max_sen_len=args.max_len
        )
        if "tf" in args.format.lower():
            train.write_to_tfrecord(
                path=os.path.join(args.output_folder, "gigaword_train_dataset.tfrecord")
            )
        else:
            train.write_to_mindrecord(
                path=os.path.join(args.output_folder, "gigaword_train_dataset.mindrecord")
            )
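
    # Test split: no noise is added (add_noise_prob=0).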
    if args.test_src and args.test_ref:
        test = BiLingualDataLoader(
            src_filepath=args.test_src,
            tgt_filepath=args.test_ref,
            src_dict=vocab, tgt_dict=vocab,
            src_lang="en", tgt_lang="en",
            language_model=NoiseChannelLanguageModel(add_noise_prob=0),
            max_sen_len=args.max_len
        )
        if "tf" in args.format.lower():
            test.write_to_tfrecord(
                path=os.path.join(args.output_folder, "gigaword_test_dataset.tfrecord")
            )
        else:
            test.write_to_mindrecord(
                path=os.path.join(args.output_folder, "gigaword_test_dataset.mindrecord")
            )

    print(f" | Vocabulary size: {vocab.size}.")
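
A typical invocation might look like the following; the file paths are placeholders, while the flag names come from the argparse definitions above:

    python gigaword.py \
        --train_src train.src.txt \
        --train_ref train.ref.txt \
        --existed_vocab vocab.bin \
        --noise_prob 0.1 \
        --output_folder ./output \
        --format tfrecord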