|
|
|
@@ -1,4 +1,4 @@ |
|
|
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
@@ -63,6 +63,33 @@ class FastTextDataPreProcess(): |
|
|
|
self.vec2words[1] = 'UNK'
|
|
|
|
self.str_html = re.compile(r'<[^>]+>')
|
|
|
|
|
|
|
|
def common_block(self, _pair_sen, spacy_nlp):
    """Tokenize one dataset row and return its tokens, token count and label.

    Args:
        _pair_sen: Sequence of row fields. Field 0 is converted to ``int`` and
            decremented by one to form the label index (presumably the labels
            in the file are 1-based -- confirm against the dataset). The
            remaining one to three fields are text.
        spacy_nlp: Passed through unchanged to ``self.input_preprocess``.

    Returns:
        A tuple ``(src_tokens, src_tokens_length, label_idx)`` where
        ``src_tokens`` is whatever ``self.input_preprocess`` returns and
        ``src_tokens_length`` is ``len(src_tokens)``.

    Raises:
        ValueError: If the row does not have 2, 3 or 4 fields, or if field 0
            cannot be parsed as an integer.
    """
    label_idx = int(_pair_sen[0]) - 1

    field_count = len(_pair_sen)
    if field_count == 2:
        # Label plus a single text field.
        first_text, second_text = _pair_sen[1], None
    elif field_count == 3:
        # Label plus two text fields.
        first_text, second_text = _pair_sen[1], _pair_sen[2]
    elif field_count == 4:
        # Label plus three text fields: fields 1 and 2 are joined with a
        # space unless field 2 is empty, and field 3 is the second text.
        if _pair_sen[2]:
            first_text = _pair_sen[1] + ' ' + _pair_sen[2]
        else:
            first_text = _pair_sen[1]
        second_text = _pair_sen[3]
    else:
        # Previously no branch matched and the function failed later with an
        # UnboundLocalError; fail early with a message that names the problem.
        raise ValueError(
            "expected a row with 2, 3 or 4 fields, got {}".format(field_count))

    src_tokens = self.input_preprocess(src_text1=first_text,
                                       src_text2=second_text,
                                       spacy_nlp=spacy_nlp,
                                       train_mode=True)
    return src_tokens, len(src_tokens), label_idx
|
|
|
|
|
|
|
|
def load(self):
|
|
|
|
"""data preprocess loader"""
|
|
|
|
train_dataset_list = []
|
|
|
|
@@ -73,30 +100,8 @@ class FastTextDataPreProcess(): |
|
|
|
with open(self.train_path, 'r', newline='', encoding='utf-8') as src_file:
|
|
|
|
reader = csv.reader(src_file, delimiter=",", quotechar='"')
|
|
|
|
for _, _pair_sen in enumerate(reader):
|
|
|
|
label_idx = int(_pair_sen[0]) - 1
|
|
|
|
if len(_pair_sen) == 3:
|
|
|
|
src_tokens = self.input_preprocess(src_text1=_pair_sen[1],
|
|
|
|
src_text2=_pair_sen[2],
|
|
|
|
spacy_nlp=spacy_nlp,
|
|
|
|
train_mode=True)
|
|
|
|
src_tokens_length = len(src_tokens)
|
|
|
|
elif len(_pair_sen) == 2:
|
|
|
|
src_tokens = self.input_preprocess(src_text1=_pair_sen[1],
|
|
|
|
src_text2=None,
|
|
|
|
spacy_nlp=spacy_nlp,
|
|
|
|
train_mode=True)
|
|
|
|
src_tokens_length = len(src_tokens)
|
|
|
|
elif len(_pair_sen) == 4:
|
|
|
|
if _pair_sen[2]:
|
|
|
|
sen_o_t = _pair_sen[1] + ' ' + _pair_sen[2]
|
|
|
|
else:
|
|
|
|
sen_o_t = _pair_sen[1]
|
|
|
|
src_tokens = self.input_preprocess(src_text1=sen_o_t,
|
|
|
|
src_text2=_pair_sen[3],
|
|
|
|
spacy_nlp=spacy_nlp,
|
|
|
|
train_mode=True)
|
|
|
|
src_tokens_length = len(src_tokens)
|
|
|
|
|
|
|
|
src_tokens, src_tokens_length, label_idx = self.common_block(_pair_sen=_pair_sen,
|
|
|
|
spacy_nlp=spacy_nlp)
|
|
|
|
train_dataset_list.append([src_tokens, src_tokens_length, label_idx])
|
|
|
|
|
|
|
|
print("Begin to process test data...")
|
|
|
|
@@ -274,7 +279,7 @@ if __name__ == '__main__': |
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
parser.add_argument('--train_file', type=str, required=True, help='train dataset file path')
|
|
|
|
parser.add_argument('--test_file', type=str, required=True, help='test dataset file path')
|
|
|
|
parser.add_argument('--class_num', type=int, required=True, help='Dataset classe number')
|
|
|
|
parser.add_argument('--class_num', type=int, required=True, help='Dataset class number')
|
|
|
|
parser.add_argument('--ngram', type=int, default=2, required=False)
|
|
|
|
parser.add_argument('--max_len', type=int, required=False, help='max length sentence in dataset')
|
|
|
|
parser.add_argument('--bucket', type=ast.literal_eval, default=[64, 128, 467], help='bucket sequence length.')
|
|
|
|
|