| @@ -56,7 +56,7 @@ class SummarizationLoader(JsonLoader): | |||||
| return ds | return ds | ||||
| def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab=True): | |||||
| def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True): | |||||
| """ | """ | ||||
| :param paths: dict path for each dataset | :param paths: dict path for each dataset | ||||
| :param vocab_size: int max_size for vocab | :param vocab_size: int max_size for vocab | ||||
| @@ -65,7 +65,7 @@ class SummarizationLoader(JsonLoader): | |||||
| :param doc_max_timesteps: int max sentence number of the document | :param doc_max_timesteps: int max sentence number of the document | ||||
| :param domain: bool build vocab for publication, use 'X' for unknown | :param domain: bool build vocab for publication, use 'X' for unknown | ||||
| :param tag: bool build vocab for tag, use 'X' for unknown | :param tag: bool build vocab for tag, use 'X' for unknown | ||||
| :param load_vocab: bool build vocab (False) or load vocab (True) | |||||
| :param load_vocab_file: bool True to load an existing vocab, False to build a new vocab from the training dataset | |||||
| :return: DataBundle | :return: DataBundle | ||||
| datasets: dict keys correspond to the paths dict | datasets: dict keys correspond to the paths dict | ||||
| vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True) | vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True) | ||||
| @@ -146,7 +146,7 @@ class SummarizationLoader(JsonLoader): | |||||
| train_ds = datasets[key] | train_ds = datasets[key] | ||||
| vocab_dict = {} | vocab_dict = {} | ||||
| if load_vocab == False: | |||||
| if load_vocab_file == False: | |||||
| logger.info("[INFO] Build new vocab from training dataset!") | logger.info("[INFO] Build new vocab from training dataset!") | ||||
| if train_ds == None: | if train_ds == None: | ||||
| raise ValueError("Lack train file to build vocabulary!") | raise ValueError("Lack train file to build vocabulary!") | ||||
| @@ -36,8 +36,8 @@ import pickle | |||||
| from nltk.tokenize import sent_tokenize | from nltk.tokenize import sent_tokenize | ||||
| import utils | |||||
| from logger import * | |||||
| import tools.utils | |||||
| from tools.logger import * | |||||
| # <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids. | # <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids. | ||||
| SENTENCE_START = '<s>' | SENTENCE_START = '<s>' | ||||
| @@ -313,7 +313,8 @@ class Example(object): | |||||
| for sent in article_sents: | for sent in article_sents: | ||||
| article_words = sent.split() | article_words = sent.split() | ||||
| self.enc_sent_len.append(len(article_words)) # store the length after truncation but before padding | self.enc_sent_len.append(len(article_words)) # store the length after truncation but before padding | ||||
| self.enc_sent_input.append([vocab.word2id(w) for w in article_words]) # list of word ids; OOVs are represented by the id for UNK token | |||||
| # self.enc_sent_input.append([vocab.word2id(w) for w in article_words]) # list of word ids; OOVs are represented by the id for UNK token | |||||
| self.enc_sent_input.append([vocab.word2id(w.lower()) for w in article_words]) # list of word ids; OOVs are represented by the id for UNK token | |||||
| self._pad_encoder_input(vocab.word2id('[PAD]')) | self._pad_encoder_input(vocab.word2id('[PAD]')) | ||||
| # Store the original strings | # Store the original strings | ||||
| @@ -29,7 +29,7 @@ import torch.nn | |||||
| os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | ||||
| os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | ||||
| sys.path.append('/remote-home/dqwang/FastNLP/fastNLP/') | |||||
| sys.path.append('/remote-home/dqwang/FastNLP/fastNLP_brxx/') | |||||
| from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||