|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Apply bpe script."""
- import os
- import argparse
- from multiprocessing import Pool, cpu_count
-
- from src.utils import Dictionary
- from src.utils import bpe_encode
-
def _build_parser():
    """Create the command-line parser for the BPE application script.

    Returns:
        argparse.ArgumentParser, parser with all supported options registered.
    """
    cli = argparse.ArgumentParser(description='Apply BPE.')

    # One row per option: (flag, type, default, required, help text).
    option_table = (
        ("--codes", str, "", True,
         "bpe codes path."),
        ("--src_folder", str, "", True,
         "raw corpus folder."),
        ("--output_folder", str, "", True,
         "encoded corpus output path."),
        ("--prefix", str, "", False,
         "Prefix of text file."),
        ("--vocab_path", str, "", True,
         "Generated vocabulary output path."),
        ("--threshold", int, None, False,
         "Filter out words that frequency is lower than threshold."),
        ("--processes", int, 2, False,
         "Number of processes to use."),
    )
    for flag, value_type, fallback, mandatory, description in option_table:
        cli.add_argument(flag, type=value_type, default=fallback,
                         required=mandatory, help=description)
    return cli


parser = _build_parser()
-
def bpe_output_paths(output_folder, file_name):
    """Derive the output paths for one raw ``.txt`` corpus file.

    Only the trailing ``.txt`` extension is rewritten; any other ``.txt``
    occurring inside the file name is left untouched.

    Args:
        output_folder (str): Folder the outputs are placed in.
        file_name (str): Base name of the source file; must end with ``.txt``.

    Returns:
        tuple, ``(encoded_path, dict_path)`` where the encoded path ends with
        ``_bpe.txt`` and the dictionary path ends with ``.dict``.
    """
    stem = file_name[:-len(".txt")]
    return (os.path.join(output_folder, stem + "_bpe.txt"),
            os.path.join(output_folder, stem + ".dict"))


if __name__ == '__main__':
    args, _ = parser.parse_known_args()

    # `--vocab_path` is required as well, so reject an empty value up front
    # instead of failing after all encoding work has been done.
    if not (args.codes and args.src_folder and args.output_folder
            and args.vocab_path):
        raise ValueError("Please enter required params.")

    source_folder = args.src_folder
    output_folder = args.output_folder
    codes = args.codes

    if not os.path.exists(codes):
        raise FileNotFoundError("`--codes` is not existed.")
    if not os.path.exists(source_folder) or not os.path.isdir(source_folder):
        raise ValueError("`--src_folder` must be a dir and existed.")
    if not os.path.exists(output_folder) or not os.path.isdir(output_folder):
        raise ValueError("`--output_folder` must be a dir and existed.")
    if not isinstance(args.prefix, str) or len(args.prefix) > 128:
        raise ValueError("`--prefix` must be a str and len <= 128.")
    if not isinstance(args.processes, int):
        raise TypeError("`--processes` must be an integer.")

    available_dict = []
    args_groups = []
    # Sorted so the job order (and the order dictionaries are merged in) does
    # not depend on the arbitrary ordering returned by the file system.
    for file_name in sorted(os.listdir(source_folder)):
        if args.prefix and not file_name.startswith(args.prefix):
            continue
        if not file_name.endswith(".txt"):
            continue
        output_path, dict_path = bpe_output_paths(output_folder, file_name)
        available_dict.append(dict_path)
        args_groups.append((codes, os.path.join(source_folder, file_name),
                            output_path, dict_path))

    # At least one worker, and never more workers than CPUs.
    kernel_size = min(max(args.processes, 1), cpu_count())
    with Pool(kernel_size) as pool:
        pending = [pool.apply_async(bpe_encode, args=job)
                   for job in args_groups]
        pool.close()
        # `get()` re-raises any exception raised in a worker; without it a
        # failed encoding job would go unnoticed and the vocabulary would be
        # built from missing or incomplete dictionary files.
        for task in pending:
            task.get()
        pool.join()

    # NOTE(review): presumably each `bpe_encode` job writes its `dict_path`
    # file, which is then merged here -- confirm against src.utils.
    vocab = Dictionary.load_from_text(available_dict)
    if args.threshold is not None:
        vocab = vocab.shrink(args.threshold)
    vocab.persistence(args.vocab_path)
    print(f" | Vocabulary Size: {len(vocab)}")
|