Browse Source

fix code issue

pull/14905/head
zhaojichen 5 years ago
parent
commit
0ee8ccddd6
2 changed files with 37 additions and 32 deletions
  1. +6
    -6
      model_zoo/official/cv/FCN8s/src/data/build_seg_data.py
  2. +31
    -26
      model_zoo/official/nlp/fasttext/src/dataset.py

+ 6
- 6
model_zoo/official/cv/FCN8s/src/data/build_seg_data.py View File

@@ -38,7 +38,7 @@ def parse_args():
if __name__ == '__main__':
args = parse_args()
datas = []
data_list = []
with open(args.data_lst) as f:
lines = f.readlines()
if args.shuffle:
@@ -65,14 +65,14 @@ if __name__ == '__main__':
sample_['data'] = f.read()
with open(os.path.join(args.data_root, label_path), 'rb') as f:
sample_['label'] = f.read()
datas.append(sample_)
data_list.append(sample_)
cnt += 1
if cnt % 1000 == 0:
writer.write_raw_data(datas)
writer.write_raw_data(data_list)
print('number of samples written:', cnt)
datas = []
data_list = []
if datas:
writer.write_raw_data(datas)
if data_list:
writer.write_raw_data(data_list)
writer.commit()
print('number of samples written:', cnt)

+ 31
- 26
model_zoo/official/nlp/fasttext/src/dataset.py View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -63,6 +63,33 @@ class FastTextDataPreProcess():
self.vec2words[1] = 'UNK'
self.str_html = re.compile(r'<[^>]+>')
def common_block(self, _pair_sen, spacy_nlp):
    """Shared preprocessing step for one CSV row of the FastText dataset.

    Args:
        _pair_sen: one parsed CSV row. ``_pair_sen[0]`` is the 1-based class
            label; the remaining one to three fields are text columns
            (2, 3, or 4 total fields are supported).
        spacy_nlp: spaCy pipeline, passed through to ``input_preprocess``.

    Returns:
        tuple: ``(src_tokens, src_tokens_length, label_idx)`` where
        ``label_idx`` is the 0-based class index.

    Raises:
        ValueError: if the row does not have 2, 3, or 4 fields.
    """
    label_idx = int(_pair_sen[0]) - 1
    num_fields = len(_pair_sen)
    if num_fields == 2:
        src_tokens = self.input_preprocess(src_text1=_pair_sen[1],
                                           src_text2=None,
                                           spacy_nlp=spacy_nlp,
                                           train_mode=True)
    elif num_fields == 3:
        src_tokens = self.input_preprocess(src_text1=_pair_sen[1],
                                           src_text2=_pair_sen[2],
                                           spacy_nlp=spacy_nlp,
                                           train_mode=True)
    elif num_fields == 4:
        # 4-field rows carry two title-like columns; merge them (skipping an
        # empty second column) before pairing with the body text.
        if _pair_sen[2]:
            sen_o_t = _pair_sen[1] + ' ' + _pair_sen[2]
        else:
            sen_o_t = _pair_sen[1]
        src_tokens = self.input_preprocess(src_text1=sen_o_t,
                                           src_text2=_pair_sen[3],
                                           spacy_nlp=spacy_nlp,
                                           train_mode=True)
    else:
        # The original fell through with src_tokens unbound, so the return
        # statement raised UnboundLocalError; fail loudly and clearly instead.
        raise ValueError('unexpected number of CSV fields: %d' % num_fields)
    # Length was computed identically in every branch; hoist it to one place.
    return src_tokens, len(src_tokens), label_idx
def load(self):
"""data preprocess loader"""
train_dataset_list = []
@@ -73,30 +100,8 @@ class FastTextDataPreProcess():
with open(self.train_path, 'r', newline='', encoding='utf-8') as src_file:
reader = csv.reader(src_file, delimiter=",", quotechar='"')
for _, _pair_sen in enumerate(reader):
label_idx = int(_pair_sen[0]) - 1
if len(_pair_sen) == 3:
src_tokens = self.input_preprocess(src_text1=_pair_sen[1],
src_text2=_pair_sen[2],
spacy_nlp=spacy_nlp,
train_mode=True)
src_tokens_length = len(src_tokens)
elif len(_pair_sen) == 2:
src_tokens = self.input_preprocess(src_text1=_pair_sen[1],
src_text2=None,
spacy_nlp=spacy_nlp,
train_mode=True)
src_tokens_length = len(src_tokens)
elif len(_pair_sen) == 4:
if _pair_sen[2]:
sen_o_t = _pair_sen[1] + ' ' + _pair_sen[2]
else:
sen_o_t = _pair_sen[1]
src_tokens = self.input_preprocess(src_text1=sen_o_t,
src_text2=_pair_sen[3],
spacy_nlp=spacy_nlp,
train_mode=True)
src_tokens_length = len(src_tokens)
src_tokens, src_tokens_length, label_idx = self.common_block(_pair_sen=_pair_sen,
spacy_nlp=spacy_nlp)
train_dataset_list.append([src_tokens, src_tokens_length, label_idx])
print("Begin to process test data...")
@@ -274,7 +279,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--train_file', type=str, required=True, help='train dataset file path')
parser.add_argument('--test_file', type=str, required=True, help='test dataset file path')
parser.add_argument('--class_num', type=int, required=True, help='Dataset classe number')
parser.add_argument('--class_num', type=int, required=True, help='Dataset class number')
parser.add_argument('--ngram', type=int, default=2, required=False)
parser.add_argument('--max_len', type=int, required=False, help='max length sentence in dataset')
parser.add_argument('--bucket', type=ast.literal_eval, default=[64, 128, 467], help='bucket sequence length.')


Loading…
Cancel
Save