diff --git a/model_zoo/official/cv/FCN8s/src/data/build_seg_data.py b/model_zoo/official/cv/FCN8s/src/data/build_seg_data.py index 8822faefd5..b5655d7166 100644 --- a/model_zoo/official/cv/FCN8s/src/data/build_seg_data.py +++ b/model_zoo/official/cv/FCN8s/src/data/build_seg_data.py @@ -38,7 +38,7 @@ def parse_args(): if __name__ == '__main__': args = parse_args() - datas = [] + data_list = [] with open(args.data_lst) as f: lines = f.readlines() if args.shuffle: @@ -65,14 +65,14 @@ if __name__ == '__main__': sample_['data'] = f.read() with open(os.path.join(args.data_root, label_path), 'rb') as f: sample_['label'] = f.read() - datas.append(sample_) + data_list.append(sample_) cnt += 1 if cnt % 1000 == 0: - writer.write_raw_data(datas) + writer.write_raw_data(data_list) print('number of samples written:', cnt) - datas = [] + data_list = [] - if datas: - writer.write_raw_data(datas) + if data_list: + writer.write_raw_data(data_list) writer.commit() print('number of samples written:', cnt) diff --git a/model_zoo/official/nlp/fasttext/src/dataset.py b/model_zoo/official/nlp/fasttext/src/dataset.py index 7135bd4590..dc6810315c 100644 --- a/model_zoo/official/nlp/fasttext/src/dataset.py +++ b/model_zoo/official/nlp/fasttext/src/dataset.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -63,6 +63,33 @@ class FastTextDataPreProcess(): self.vec2words[1] = 'UNK' self.str_html = re.compile(r'<[^>]+>') + def common_block(self, _pair_sen, spacy_nlp): + """common block for data preprocessing""" + label_idx = int(_pair_sen[0]) - 1 + if len(_pair_sen) == 3: + src_tokens = self.input_preprocess(src_text1=_pair_sen[1], + src_text2=_pair_sen[2], + spacy_nlp=spacy_nlp, + train_mode=True) + src_tokens_length = len(src_tokens) + elif len(_pair_sen) == 2: + src_tokens = self.input_preprocess(src_text1=_pair_sen[1], + src_text2=None, + spacy_nlp=spacy_nlp, + train_mode=True) + src_tokens_length = len(src_tokens) + elif len(_pair_sen) == 4: + if _pair_sen[2]: + sen_o_t = _pair_sen[1] + ' ' + _pair_sen[2] + else: + sen_o_t = _pair_sen[1] + src_tokens = self.input_preprocess(src_text1=sen_o_t, + src_text2=_pair_sen[3], + spacy_nlp=spacy_nlp, + train_mode=True) + src_tokens_length = len(src_tokens) + return src_tokens, src_tokens_length, label_idx + def load(self): """data preprocess loader""" train_dataset_list = [] @@ -73,30 +100,8 @@ class FastTextDataPreProcess(): with open(self.train_path, 'r', newline='', encoding='utf-8') as src_file: reader = csv.reader(src_file, delimiter=",", quotechar='"') for _, _pair_sen in enumerate(reader): - label_idx = int(_pair_sen[0]) - 1 - if len(_pair_sen) == 3: - src_tokens = self.input_preprocess(src_text1=_pair_sen[1], - src_text2=_pair_sen[2], - spacy_nlp=spacy_nlp, - train_mode=True) - src_tokens_length = len(src_tokens) - elif len(_pair_sen) == 2: - src_tokens = self.input_preprocess(src_text1=_pair_sen[1], - src_text2=None, - spacy_nlp=spacy_nlp, - train_mode=True) - src_tokens_length = len(src_tokens) - elif len(_pair_sen) == 4: - if _pair_sen[2]: - sen_o_t = _pair_sen[1] + ' ' + _pair_sen[2] - else: - sen_o_t = _pair_sen[1] - src_tokens = self.input_preprocess(src_text1=sen_o_t, - src_text2=_pair_sen[3], - spacy_nlp=spacy_nlp, - train_mode=True) - src_tokens_length = len(src_tokens) - + src_tokens, src_tokens_length, label_idx = self.common_block(_pair_sen=_pair_sen, + spacy_nlp=spacy_nlp) train_dataset_list.append([src_tokens, src_tokens_length, label_idx]) print("Begin to process test data...") @@ -274,7 +279,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--train_file', type=str, required=True, help='train dataset file path') parser.add_argument('--test_file', type=str, required=True, help='test dataset file path') - parser.add_argument('--class_num', type=int, required=True, help='Dataset classe number') + parser.add_argument('--class_num', type=int, required=True, help='Dataset class number') parser.add_argument('--ngram', type=int, default=2, required=False) parser.add_argument('--max_len', type=int, required=False, help='max length sentence in dataset') parser.add_argument('--bucket', type=ast.literal_eval, default=[64, 128, 467], help='bucket sequence length.')