@@ -201,9 +201,6 @@ class BeamSearch(Generator):
 predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2)
 state = repeat(state, beam_size)
-parent_idx_list = []
-pred_list = []
 if max_gen_len is None:
     max_gen_len = self.max_gen_len
 for step in range(2, max_gen_len + 1):
@@ -229,11 +226,11 @@ class BeamSearch(Generator):
 pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \
     (1 - torch.not_equal(pre_ids, self.pad_id).float())
-scores = scores * (1 - pre_eos_mask) + \
-    pre_eos_mask.repeat(1, 1, self.vocab_size) * scores_after_end
+scores = scores * (1 - pre_eos_mask) + pre_eos_mask.repeat(
+    1, 1, self.vocab_size) * scores_after_end
 if self.length_average:
-    scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 -
-                                                        1 / step)
+    scaled_value = \
+        pre_eos_mask + (1 - pre_eos_mask) * (1 - 1 / step)
     sequence_scores = sequence_scores.unsqueeze(2) * scaled_value
     scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 / step)
     scores = scores * scaled_value
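Note on the scoring logic in this hunk: `pre_eos_mask` flags beams that have already produced `eos_id` (or padding), `scores_after_end` freezes them on a no-cost continuation, and the two `scaled_value` lines maintain a running length average (the old average covers `step - 1` tokens, the new token contributes `1 / step`). A minimal standalone sketch of that update, with assumed shapes and names (`seq_scores`, `step_scores`, `finished` are mine, not the repo's):

    import torch

    def update_beam_scores(seq_scores, step_scores, finished, step,
                           pad_id=0, neg_inf=-1e9):
        # seq_scores: [batch, beam]; step_scores: [batch, beam, vocab]
        # finished: [batch, beam, 1], 1.0 where the beam already emitted EOS.
        frozen = torch.full_like(step_scores, neg_inf)
        frozen[..., pad_id] = 0.0  # finished beams extend with pad for free
        step_scores = step_scores * (1 - finished) + frozen * finished
        # Length average: rescale the old mean, weight the new token by 1/step.
        seq_scores = seq_scores.unsqueeze(2) * (
            finished + (1 - finished) * (1 - 1 / step))
        step_scores = step_scores * (finished + (1 - finished) * (1 / step))
        return seq_scores + step_scores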
@@ -171,7 +171,7 @@ class BPETextField(object):
 src_token.append(list(chain(*utts))[-self.max_len:])
 # Position ids
-pos = [list(range(l)) for l in utt_lens]
+pos = [list(range(utt_len)) for utt_len in utt_lens]
 src_pos.append(list(chain(*pos))[-self.max_len:])
 # Turn ids
@@ -205,15 +205,15 @@ class BPETextField(object):
 understand = [self.understand_ids for _ in samples]
 understand_token = np.array(understand).astype('int64')
 batch['understand_token'] = understand_token
-batch['understand_mask'] = (understand_token !=
-                            self.pad_id).astype('int64')
+batch['understand_mask'] = \
+    (understand_token != self.pad_id).astype('int64')
 if self.policy_ids and self.policy:
     policy = [self.policy_ids for _ in samples]
     policy_token = np.array(policy).astype('int64')
     batch['policy_token'] = policy_token
-    batch['policy_mask'] = (policy_token !=
-                            self.pad_id).astype('int64')
+    batch['policy_mask'] = \
+        (policy_token != self.pad_id).astype('int64')
 if 'tgt' in samples[0]:
     tgt = [sp['tgt'] for sp in samples]
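Both masks here are plain padding masks: a position is active iff its token id differs from `pad_id`. For illustration (the values are made up):

    import numpy as np

    pad_id = 0
    token = np.array([[5, 9, 2, 0, 0],
                      [7, 3, 0, 0, 0]], dtype='int64')
    mask = (token != pad_id).astype('int64')
    # mask -> [[1 1 1 0 0],
    #          [1 1 0 0 0]]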
@@ -421,7 +421,7 @@ class MultiWOZBPETextField(BPETextField):
 try:
     log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % (
         k, len(turn_bucket[k]), len(batches), len(batches[-1]))
-except:
+except Exception:
     log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % (
         k, len(turn_bucket[k]), len(batches), 0.0)
 # print("turn num:%d, dial num:v%d, batch num: %d, "%(k, len(turn_bucket[k]), len(batches)))
@@ -520,7 +520,7 @@ class MultiWOZBPETextField(BPETextField):
         ns) is not str else ns
     if ns == "'s":
         continue
-except:
+except Exception:
     continue
 if not constraint_dict.get(domain):
     constraint_dict[domain] = {}
@@ -670,10 +670,9 @@ class MultiWOZBPETextField(BPETextField):
         pv_turn['aspn'] + pv_turn['resp']
     ]
 else:
-    pv_context = pv_turn['labels'] + [
-        pv_turn['bspn'] + pv_turn['db'] + pv_turn['aspn'] +
-        pv_turn['resp']
-    ]
+    pv_info = pv_turn['bspn'] + pv_turn['db'] + pv_turn[
+        'aspn'] + pv_turn['resp']
+    pv_context = pv_turn['labels'] + [pv_info]
 # prompt response, add sos_r
 inputs['src'] = pv_context + [context]
@@ -229,9 +229,8 @@ class BPETextField(object):
 # 10% randomly change token to random token
 elif prob < 0.9:
-    output_chars.append(
-        random.randint(1, self.vocab_size -
-                       1))  # start from 1, to exclude pad_id
+    tmp = random.randint(1, self.vocab_size - 1)
+    output_chars.append(tmp)  # start from 1, to exclude pad_id
 # 10% randomly change token to current token
 else:
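This branch sits inside what is presumably the standard BERT-style corruption scheme: 80% of the selected tokens become the mask token, 10% a random token (id 1 upwards, so `pad_id` 0 is excluded), 10% stay unchanged. Sketched in isolation, with a hypothetical `mask_id`:

    import random

    def corrupt_token(token_id, mask_id, vocab_size):
        prob = random.random()
        if prob < 0.8:      # 80%: replace with the mask token
            return mask_id
        elif prob < 0.9:    # 10%: replace with a random non-pad token
            return random.randint(1, vocab_size - 1)
        else:               # 10%: keep the original token
            return token_id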
@@ -401,7 +400,7 @@ class BPETextField(object):
     build symmetric score matrix
     """
     assert self.num_process == 1
-    print(f'Building score matrix from examples ...')
+    print('Building score matrix from examples ...')
     num = len(examples)
     score_matrix = np.eye(
         num, num, dtype='float32'
@@ -415,7 +414,7 @@ class BPETextField(object):
         score_matrix[i][j] = score
         score_matrix[j][i] = score
-    print(f'Built score matrix')
+    print('Built score matrix')
     return score_matrix
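`build_score_matrix` starts from `np.eye` (self-similarity 1) and writes each pairwise score into both `[i][j]` and `[j][i]`, so only the upper triangle is computed. The pattern in miniature, with a placeholder `score_fn`:

    import numpy as np

    def build_symmetric_matrix(items, score_fn):
        num = len(items)
        matrix = np.eye(num, dtype='float32')   # diagonal = 1.0
        for i in range(num):
            for j in range(i + 1, num):         # upper triangle only
                s = score_fn(items[i], items[j])
                matrix[i][j] = s
                matrix[j][i] = s                # mirror
        return matrix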
 def build_score_matrix_on_the_fly(self,
@@ -482,7 +481,7 @@ class BPETextField(object):
     build score matrix
     """
     assert self.num_process >= 2 and multiprocessing.cpu_count() >= 2
-    print(f'Building score matrix from examples ...')
+    print('Building score matrix from examples ...')
     results = []
     num = len(examples)
     sub_num, res_num = num // self.num_process, num % self.num_process
@@ -512,7 +511,7 @@ class BPETextField(object):
         score_matrix,
         1.)  # in case of empty label of self, resulting in score 0.
-    print(f'Built score matrix')
+    print('Built score matrix')
     return score_matrix

 def extract_span_texts(self, text, label):
@@ -556,7 +555,7 @@ class BPETextField(object):
 token_list = [
     tok for tok in map(str.strip,
-                       re.split('(\W+)', text.lower()))
+                       re.split('(\\W+)', text.lower()))
     if len(tok) > 0
 ]
 span_list = np.zeros(len(token_list), dtype=np.int32)
@@ -586,10 +585,10 @@ class BPETextField(object):
 history_span_mask.append(span_mask)
 history_label.append(self.fix_label(label))
+tmp = self.utts_filter_pred(history[:-1]) and all(
+    map(self.utt_filter_pred, history))
 if (
-    (self.utts_filter_pred(history[:-1])
-     and all(map(self.utt_filter_pred, history)))
-    or data_type == 'test'
+    tmp or data_type == 'test'
 ) and role in self.trigger_role and t:  # TODO consider test
     src = [
         s[-self.max_utt_len:]
@@ -603,11 +602,18 @@ class BPETextField(object):
         role
         for role in history_role[:-1][-self.max_ctx_turn:]
     ]
-    src = [[self.sos_u_id] + self.numericalize(s) +
-           [self.eos_u_id]
-           if roles[i] == 'user' else [self.sos_r_id] +
-           self.numericalize(s) + [self.eos_r_id]
-           for i, s in enumerate(src)]
+    new_src = []
+    for i, s in enumerate(src):
+        # Wrap each utterance in the boundary ids of its speaker role.
+        if roles[i] == 'user':
+            sos_id, eos_id = self.sos_u_id, self.eos_u_id
+        else:
+            sos_id, eos_id = self.sos_r_id, self.eos_r_id
+        new_src.append([sos_id] + self.numericalize(s) + [eos_id])
     src_span_mask = [[0] + list(map(int, s)) + [0]
                      for s in src_span_mask]
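The loop above wraps each context utterance in role-specific boundary ids: user turns get `sos_u_id`/`eos_u_id`, system turns `sos_r_id`/`eos_r_id`. As a free-standing helper (hypothetical, not one the repo defines):

    def wrap_utterance(ids, role, sos_u_id, eos_u_id, sos_r_id, eos_r_id):
        # ids: the numericalized utterance; role: 'user' or a system role.
        if role == 'user':
            return [sos_u_id] + ids + [eos_u_id]
        return [sos_r_id] + ids + [eos_r_id]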
@@ -619,7 +625,7 @@ class BPETextField(object):
 ex = {
     'dialog_id': dialog_id,
     'turn_id': turn['turn_id'],
-    'src': src,
+    'src': new_src,
     'src_span_mask': src_span_mask,
     'tgt': tgt,
     'query_label': history_label[-2],
@@ -654,7 +660,7 @@ class BPETextField(object):
 history, history_role, history_span_mask = [], [], []
 utterance, span_mask = [], []
 token_list = [
-    tok for tok in map(str.strip, re.split('(\W+)', text.lower()))
+    tok for tok in map(str.strip, re.split('(\\W+)', text.lower()))
     if len(tok) > 0
 ]
 span_list = np.zeros(len(token_list), dtype=np.int32)
@@ -680,10 +686,17 @@ class BPETextField(object):
     for s in history_span_mask[-self.max_ctx_turn:]
 ]
 roles = [role for role in history_role[-self.max_ctx_turn:]]
-src = [[self.sos_u_id] + self.numericalize(s) +
-       [self.eos_u_id] if roles[i] == 'user' else [self.sos_r_id] +
-       self.numericalize(s) + [self.eos_r_id]
-       for i, s in enumerate(src)]
+new_src = []
+for i, s in enumerate(src):
+    # Wrap each utterance in the boundary ids of its speaker role.
+    if roles[i] == 'user':
+        sos_id, eos_id = self.sos_u_id, self.eos_u_id
+    else:
+        sos_id, eos_id = self.sos_r_id, self.eos_r_id
+    new_src.append([sos_id] + self.numericalize(s) + [eos_id])
 src_span_mask = [[0] + list(map(int, s)) + [0]
                  for s in src_span_mask]
@@ -691,7 +704,7 @@ class BPETextField(object):
 'dialog_id': 'inference',
 'turn_id': 0,
 'role': role,
-'src': src,
+'src': new_src,
 'src_span_mask': src_span_mask,
 'query_label': {
     'DEFAULT_DOMAIN': {
@@ -734,7 +747,7 @@ class BPETextField(object):
 token_list = [
     tok for tok in map(str.strip,
-                       re.split('(\W+)', text.lower()))
+                       re.split('(\\W+)', text.lower()))
     if len(tok) > 0
 ]
 span_list = np.zeros(len(token_list), dtype=np.int32)
@@ -763,10 +776,10 @@ class BPETextField(object):
 history_role.append(role)
 history_span_mask.append(span_mask)
-if ((self.utts_filter_pred(history)
-     and all(map(self.utt_filter_pred, history)))
-    or data_type == 'test'
-    ) and role in self.trigger_role:  # TODO consider test
+tmp = self.utts_filter_pred(history) and all(
+    map(self.utt_filter_pred, history))
+tmp = tmp or data_type == 'test'
+if tmp and role in self.trigger_role:  # TODO consider test
     src = [
         s[-self.max_utt_len:]
         for s in history[-self.max_ctx_turn:]
@@ -778,11 +791,17 @@ class BPETextField(object):
     roles = [
         role for role in history_role[-self.max_ctx_turn:]
     ]
-    src = [[self.sos_u_id] + self.numericalize(s) +
-           [self.eos_u_id]
-           if roles[i] == 'user' else [self.sos_r_id] +
-           self.numericalize(s) + [self.eos_r_id]
-           for i, s in enumerate(src)]
+    new_src = []
+    for i, s in enumerate(src):
+        # Wrap each utterance in the boundary ids of its speaker role.
+        if roles[i] == 'user':
+            sos_id, eos_id = self.sos_u_id, self.eos_u_id
+        else:
+            sos_id, eos_id = self.sos_r_id, self.eos_r_id
+        new_src.append([sos_id] + self.numericalize(s) + [eos_id])
     src_span_mask = [[0] + list(map(int, s)) + [0]
                      for s in src_span_mask]
@@ -790,7 +809,7 @@ class BPETextField(object):
 'dialog_id': dialog_id,
 'turn_id': turn['turn_id'],
 'role': role,
-'src': src,
+'src': new_src,
 'src_span_mask': src_span_mask,
 'query_label': self.fix_label(label),
 'extra_info': turn.get('extra_info', '')
@@ -829,7 +848,7 @@ class BPETextField(object):
 src_token.append(list(chain(*utts))[-self.max_len:])
 # Position ids
-pos = [list(range(l)) for l in utt_lens]
+pos = [list(range(utt_len)) for utt_len in utt_lens]
 src_pos.append(list(chain(*pos))[-self.max_len:])
 # Turn ids
@@ -887,15 +906,15 @@ class BPETextField(object):
 understand = [self.understand_ids for _ in samples]
 understand_token = np.array(understand).astype('int64')
 batch['understand_token'] = understand_token
-batch['understand_mask'] = (understand_token !=
-                            self.pad_id).astype('int64')
+batch['understand_mask'] = \
+    (understand_token != self.pad_id).astype('int64')
 if self.policy_ids and self.policy:
     policy = [self.policy_ids for _ in samples]
     policy_token = np.array(policy).astype('int64')
     batch['policy_token'] = policy_token
-    batch['policy_mask'] = (policy_token !=
-                            self.pad_id).astype('int64')
+    batch['policy_mask'] = \
+        (policy_token != self.pad_id).astype('int64')
 if 'tgt' in samples[0]:
     tgt = [sp['tgt'] for sp in samples]
@@ -952,8 +971,8 @@ class IntentBPETextField(BPETextField):
 # One example for each label
 example_inds = []
-for l in set(labels.tolist()):
-    if l == -1:
+for label in set(labels.tolist()):
+    if label == -1:
         continue
-    ind = random.choice(cache[l])
+    ind = random.choice(cache[label])
@@ -1001,7 +1020,7 @@ class IntentBPETextField(BPETextField):
 src_token.append(list(chain(*utts))[-self.max_len:])
 # Position ids
-pos = [list(range(l)) for l in utt_lens]
+pos = [list(range(utt_len)) for utt_len in utt_lens]
 src_pos.append(list(chain(*pos))[-self.max_len:])
 # Turn ids
@@ -325,13 +325,15 @@ class BasicTokenizer(object):
 # as is Japanese Hiragana and Katakana. Those alphabets are used to write
 # space-separated words, so they are not treated specially and handled
 # like all of the other languages.
-if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF)
-        or (cp >= 0x20000 and cp <= 0x2A6DF)
-        or (cp >= 0x2A700 and cp <= 0x2B73F)
-        or (cp >= 0x2B740 and cp <= 0x2B81F)
-        or (cp >= 0x2B820 and cp <= 0x2CEAF)
-        or (cp >= 0xF900 and cp <= 0xFAFF)
-        or (cp >= 0x2F800 and cp <= 0x2FA1F)):
+tmp = (cp >= 0x4E00 and cp <= 0x9FFF)
+tmp = tmp or (cp >= 0x3400 and cp <= 0x4DBF)
+tmp = tmp or (cp >= 0x20000 and cp <= 0x2A6DF)
+tmp = tmp or (cp >= 0x2A700 and cp <= 0x2B73F)
+tmp = tmp or (cp >= 0x2B740 and cp <= 0x2B81F)
+tmp = tmp or (cp >= 0x2B820 and cp <= 0x2CEAF)
+tmp = tmp or (cp >= 0xF900 and cp <= 0xFAFF)
+tmp = tmp or (cp >= 0x2F800 and cp <= 0x2FA1F)
+if tmp:
     return True
 return False
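The `tmp` chain tests whether the code point falls in any CJK Unified Ideographs block (the ranges come straight from the original BERT tokenizer). Equivalent, and arguably clearer, as an `any()` over range pairs, sketched here rather than taken from the repo:

    CJK_RANGES = [
        (0x4E00, 0x9FFF), (0x3400, 0x4DBF), (0x20000, 0x2A6DF),
        (0x2A700, 0x2B73F), (0x2B740, 0x2B81F), (0x2B820, 0x2CEAF),
        (0xF900, 0xFAFF), (0x2F800, 0x2FA1F),
    ]

    def is_chinese_char(cp):
        # cp: integer code point, e.g. ord('中')
        return any(lo <= cp <= hi for lo, hi in CJK_RANGES)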
@@ -441,8 +443,11 @@ def _is_punctuation(char):
 # Characters such as "^", "$", and "`" are not in the Unicode
 # Punctuation class but we treat them as punctuation anyways, for
 # consistency.
-if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64)
-        or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
+tmp = (cp >= 33 and cp <= 47)
+tmp = tmp or (cp >= 58 and cp <= 64)
+tmp = tmp or (cp >= 91 and cp <= 96)
+tmp = tmp or (cp >= 123 and cp <= 126)
+if tmp:
     return True
 cat = unicodedata.category(char)
 if cat.startswith('P'):
@@ -589,7 +594,7 @@ class GPT2Tokenizer(object):
     j = word.index(first, i)
     new_word.extend(word[i:j])
     i = j
-except:
+except Exception:
     new_word.extend(word[i:])
     break
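The bare `except` being replaced guards `word.index(first, i)`, which raises `ValueError` once `first` stops occurring; `except Exception` satisfies the linter, though `except ValueError` would be tighter still. The scanning idiom on its own, as a simplified sketch:

    def occurrences(word, first):
        # Yield each index of `first` in the sequence `word`.
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)  # ValueError when absent
            except ValueError:
                return
            yield j
            i = j + 1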
@@ -625,8 +630,10 @@ class GPT2Tokenizer(object):
 def convert_tokens_to_ids(self, tokens):
     """ Converts a sequence of tokens into ids using the vocab. """
     ids = []
-    if isinstance(tokens, str) or (sys.version_info[0] == 2
-                                   and isinstance(tokens, unicode)):
+    is_py3_str = isinstance(tokens, str)
+    is_py2_unicode = (
+        sys.version_info[0] == 2 and isinstance(tokens, unicode))
+    if is_py3_str or is_py2_unicode:
         if tokens in self.special_tokens:
             return self.special_tokens[tokens]
         else:
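`convert_tokens_to_ids` treats a lone string (or Python 2 `unicode`) as a single token and returns one id; any other iterable yields a list of ids. Hypothetical usage, with made-up id values and an assumed constructor:

    # tokenizer = GPT2Tokenizer(vocab_file, merges_file)  # assumed setup
    # tokenizer.convert_tokens_to_ids('hello')             -> 31373
    # tokenizer.convert_tokens_to_ids(['hello', 'world'])  -> [31373, 995]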
@@ -46,8 +46,8 @@ def clean_replace(s, r, t, forward=True, backward=False):
     return s, -1
 if forward:
-    while idx_r < len(s) and (s[idx_r].isalpha()
-                              or s[idx_r].isdigit()):
+    while \
+            idx_r < len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()):
         idx_r += 1
 elif idx_r != len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()):
     return s, -1
@@ -122,13 +122,15 @@ class MultiWOZVocab(object):
     self._word2idx[word] = idx

 def construct(self):
-    l = sorted(self._freq_dict.keys(), key=lambda x: -self._freq_dict[x])
+    freq_dict_sorted = sorted(
+        self._freq_dict.keys(), key=lambda x: -self._freq_dict[x])
     print('Vocabulary size including oov: %d' %
-          (len(l) + len(self._idx2word)))
-    if len(l) + len(self._idx2word) < self.vocab_size:
+          (len(freq_dict_sorted) + len(self._idx2word)))
+    if len(freq_dict_sorted) + len(self._idx2word) < self.vocab_size:
         logging.warning(
             'actual label set smaller than that configured: {}/{}'.format(
-                len(l) + len(self._idx2word), self.vocab_size))
+                len(freq_dict_sorted) + len(self._idx2word),
+                self.vocab_size))
     for word in ontology.all_domains + ['general']:
         word = '[' + word + ']'
         self._add_to_vocab(word)
@@ -137,10 +139,10 @@ class MultiWOZVocab(object):
     self._add_to_vocab(word)
 for word in ontology.all_slots:
     self._add_to_vocab(word)
-for word in l:
+for word in freq_dict_sorted:
     if word.startswith('[value_') and word.endswith(']'):
         self._add_to_vocab(word)
-for word in l:
+for word in freq_dict_sorted:
     self._add_to_vocab(word)
 self.vocab_size_oov = len(self._idx2word)
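`construct` admits words most-frequent-first: the bracketed domain and act tokens, then slots, then `[value_*]` entries from the sorted list, then everything else. The ordering step alone, with toy counts:

    freq_dict = {'hotel': 12, 'cheap': 7, '[value_area]': 3}
    ordered = sorted(freq_dict.keys(), key=lambda w: -freq_dict[w])
    # ['hotel', 'cheap', '[value_area]']  -- most frequent first
    values = [w for w in ordered
              if w.startswith('[value_') and w.endswith(']')]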
@@ -192,13 +194,13 @@ class MultiWOZVocab(object):
     else:
         return self._idx2word[idx] + '(o)'

-def sentence_decode(self, index_list, eos=None, indicate_oov=False):
-    l = [self.decode(_, indicate_oov) for _ in index_list]
-    if not eos or eos not in l:
-        return ' '.join(l)
-    else:
-        idx = l.index(eos)
-        return ' '.join(l[:idx])
-
-def nl_decode(self, l, eos=None):
-    return [self.sentence_decode(_, eos) + '\n' for _ in l]
+# def sentence_decode(self, index_list, eos=None, indicate_oov=False):
+#     l = [self.decode(_, indicate_oov) for _ in index_list]
+#     if not eos or eos not in l:
+#         return ' '.join(l)
+#     else:
+#         idx = l.index(eos)
+#         return ' '.join(l[:idx])
+#
+# def nl_decode(self, l, eos=None):
+#     return [self.sentence_decode(_, eos) + '\n' for _ in l]