|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Generate Cornell Movie Dialog dataset."""
- import os
- import argparse
- from src.dataset import BiLingualDataLoader
- from src.language_model import NoiseChannelLanguageModel
- from src.utils import Dictionary
-
# Command-line interface of the dataset generator.  The three folder/vocab
# options are mandatory; the rest fall back to the defaults shown here.
parser = argparse.ArgumentParser(description='Generate Cornell Movie Dialog dataset file.')
# Mandatory options: argparse rejects the command line when one is missing,
# so no default value is ever consulted for them.
parser.add_argument("--src_folder", type=str, required=True,
                    help="Raw corpus folder.")
parser.add_argument("--existed_vocab", type=str, required=True,
                    help="Existed vocabulary.")
# File-name prefixes used to recognise each split inside --src_folder.
parser.add_argument("--train_prefix", type=str, default="train",
                    help="Prefix of train file.")
parser.add_argument("--test_prefix", type=str, default="test",
                    help="Prefix of test file.")
# No validation split is looked up unless a prefix is given.
parser.add_argument("--valid_prefix", type=str, default=None,
                    help="Prefix of valid file.")
parser.add_argument("--noise_prob", type=float, default=0.,
                    help="Add noise prob.")
parser.add_argument("--max_len", type=int, default=32,
                    help="Max length of sentence.")
parser.add_argument("--output_folder", type=str, required=True,
                    help="Dataset output path.")
-
def locate_corpus_files(src_folder, train_prefix, test_prefix, valid_prefix=None):
    """Find the source/target text files of every dataset split.

    Only ``.txt`` files are considered.  A file is assigned to the first
    split (train, then test, then valid) whose prefix its name starts with
    and for which the name contains a side marker; ``"src"`` is checked
    before ``"tgt"``.  The valid split is searched only when
    ``valid_prefix`` is non-empty.

    Args:
        src_folder (str): Folder that holds the raw corpus files.
        train_prefix (str): File-name prefix of the training files.
        test_prefix (str): File-name prefix of the test files.
        valid_prefix (str or None): File-name prefix of the validation
            files, or a falsy value to skip the validation split.

    Returns:
        dict: ``{"train"|"test"|"valid": {"src": path, "tgt": path}}``.
        Each path is already joined with ``src_folder``; it is ``""`` when
        no matching file was found.
    """
    splits = [("train", train_prefix), ("test", test_prefix)]
    if valid_prefix:
        splits.append(("valid", valid_prefix))

    found = {name: {"src": "", "tgt": ""} for name in ("train", "test", "valid")}
    # os.listdir() returns entries in arbitrary order; sorting makes the
    # selected file reproducible when several names match the same slot.
    for filename in sorted(os.listdir(src_folder)):
        if not filename.endswith(".txt"):
            continue
        for split, prefix in splits:
            if not filename.startswith(prefix):
                continue
            if "src" in filename:
                side = "src"
            elif "tgt" in filename:
                side = "tgt"
            else:
                # Prefix matches but the name has no side marker: let the
                # next split's prefix have a try.
                continue
            found[split][side] = os.path.join(src_folder, filename)
            break
    return found


def export_split(src_file, tgt_file, vocab, noise_prob, max_len, output_path):
    """Load one English/English split and write it out as a tfrecord file.

    Args:
        src_file (str): Path of the source-side text file.
        tgt_file (str): Path of the target-side text file.
        vocab: Vocabulary shared by the source and the target side.
        noise_prob (float): Probability handed to the noise-channel
            language model as ``add_noise_prob``.
        max_len (int): Value handed to the loader as ``max_sen_len``.
        output_path (str): Destination tfrecord path.
    """
    BiLingualDataLoader(
        src_filepath=src_file,
        tgt_filepath=tgt_file,
        src_dict=vocab, tgt_dict=vocab,
        src_lang="en", tgt_lang="en",
        language_model=NoiseChannelLanguageModel(add_noise_prob=noise_prob),
        max_sen_len=max_len
    ).write_to_tfrecord(path=output_path)


def main():
    """Convert the raw corpus splits found in ``--src_folder`` to tfrecords."""
    args, _ = parser.parse_known_args()

    corpus = locate_corpus_files(
        args.src_folder, args.train_prefix, args.test_prefix, args.valid_prefix
    )
    vocab = Dictionary.load_from_persisted_dict(args.existed_vocab)

    # Only the training split uses the user-supplied noise probability; the
    # evaluation splits are always written with a probability of zero.
    noise_by_split = {"train": args.noise_prob, "test": 0., "valid": 0.}

    for split in ("train", "test", "valid"):
        src_file = corpus[split]["src"]
        tgt_file = corpus[split]["tgt"]
        # A split is exported only when both of its files were found.  The
        # located paths already include the corpus folder, so they are
        # passed on unchanged and never joined with it a second time.
        if not (src_file and tgt_file):
            continue
        export_split(
            src_file, tgt_file, vocab,
            noise_prob=noise_by_split[split],
            max_len=args.max_len,
            output_path=os.path.join(
                args.output_folder, f"{split}_cornell_dialog.tfrecord"
            ),
        )

    print(f" | Vocabulary size: {vocab.size}.")


if __name__ == '__main__':
    main()
|