You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

apply_bpe_encoding.py 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Apply bpe script."""
  16. import os
  17. import argparse
  18. from multiprocessing import Pool, cpu_count
  19. from src.utils import Dictionary
  20. from src.utils import bpe_encode
  21. parser = argparse.ArgumentParser(description='Apply BPE.')
  22. parser.add_argument("--codes", type=str, default="", required=True,
  23. help="bpe codes path.")
  24. parser.add_argument("--src_folder", type=str, default="", required=True,
  25. help="raw corpus folder.")
  26. parser.add_argument("--output_folder", type=str, default="", required=True,
  27. help="encoded corpus output path.")
  28. parser.add_argument("--prefix", type=str, default="", required=False,
  29. help="Prefix of text file.")
  30. parser.add_argument("--vocab_path", type=str, default="", required=True,
  31. help="Generated vocabulary output path.")
  32. parser.add_argument("--threshold", type=int, default=None, required=False,
  33. help="Filter out words that frequency is lower than threshold.")
  34. parser.add_argument("--processes", type=int, default=2, required=False,
  35. help="Number of processes to use.")
  36. if __name__ == '__main__':
  37. args, _ = parser.parse_known_args()
  38. if not (args.codes and args.src_folder and args.output_folder):
  39. raise ValueError("Please enter required params.")
  40. source_folder = args.src_folder
  41. output_folder = args.output_folder
  42. codes = args.codes
  43. if not os.path.exists(codes):
  44. raise FileNotFoundError("`--codes` is not existed.")
  45. if not os.path.exists(source_folder) or not os.path.isdir(source_folder):
  46. raise ValueError("`--src_folder` must be a dir and existed.")
  47. if not os.path.exists(output_folder) or not os.path.isdir(output_folder):
  48. raise ValueError("`--output_folder` must be a dir and existed.")
  49. if not isinstance(args.prefix, str) or len(args.prefix) > 128:
  50. raise ValueError("`--prefix` must be a str and len <= 128.")
  51. if not isinstance(args.processes, int):
  52. raise TypeError("`--processes` must be an integer.")
  53. available_dict = []
  54. args_groups = []
  55. for file in os.listdir(source_folder):
  56. if args.prefix and not file.startswith(args.prefix):
  57. continue
  58. if file.endswith(".txt"):
  59. output_path = os.path.join(output_folder, file.replace(".txt", "_bpe.txt"))
  60. dict_path = os.path.join(output_folder, file.replace(".txt", ".dict"))
  61. available_dict.append(dict_path)
  62. args_groups.append((codes, os.path.join(source_folder, file),
  63. output_path, dict_path))
  64. kernel_size = 1 if args.processes <= 0 else args.processes
  65. kernel_size = min(kernel_size, cpu_count())
  66. pool = Pool(kernel_size)
  67. for arg in args_groups:
  68. pool.apply_async(bpe_encode, args=arg)
  69. pool.close()
  70. pool.join()
  71. vocab = Dictionary.load_from_text(available_dict)
  72. if args.threshold is not None:
  73. vocab = vocab.shrink(args.threshold)
  74. vocab.persistence(args.vocab_path)
  75. print(f" | Vocabulary Size: {len(vocab)}")