|
|
@@ -229,9 +229,8 @@ class BPETextField(object): |
|
|
|
|
|
|
|
|
# 10% randomly change token to random token |
|
|
# 10% randomly change token to random token |
|
|
elif prob < 0.9: |
|
|
elif prob < 0.9: |
|
|
output_chars.append( |
|
|
|
|
|
random.randint(1, self.vocab_size - |
|
|
|
|
|
1)) # start from 1, to exclude pad_id |
|
|
|
|
|
|
|
|
tmp = random.randint(1, self.vocab_size - 1) |
|
|
|
|
|
output_chars.append(tmp) # start from 1, to exclude pad_id |
|
|
|
|
|
|
|
|
# 10% randomly change token to current token |
|
|
# 10% randomly change token to current token |
|
|
else: |
|
|
else: |
|
|
@@ -401,7 +400,7 @@ class BPETextField(object): |
|
|
build symmetric score matrix |
|
|
build symmetric score matrix |
|
|
""" |
|
|
""" |
|
|
assert self.num_process == 1 |
|
|
assert self.num_process == 1 |
|
|
print(f'Building score matrix from examples ...') |
|
|
|
|
|
|
|
|
print('Building score matrix from examples ...') |
|
|
num = len(examples) |
|
|
num = len(examples) |
|
|
score_matrix = np.eye( |
|
|
score_matrix = np.eye( |
|
|
num, num, dtype='float32' |
|
|
num, num, dtype='float32' |
|
|
@@ -415,7 +414,7 @@ class BPETextField(object): |
|
|
score_matrix[i][j] = score |
|
|
score_matrix[i][j] = score |
|
|
score_matrix[j][i] = score |
|
|
score_matrix[j][i] = score |
|
|
|
|
|
|
|
|
print(f'Built score matrix') |
|
|
|
|
|
|
|
|
print('Built score matrix') |
|
|
return score_matrix |
|
|
return score_matrix |
|
|
|
|
|
|
|
|
def build_score_matrix_on_the_fly(self, |
|
|
def build_score_matrix_on_the_fly(self, |
|
|
@@ -482,7 +481,7 @@ class BPETextField(object): |
|
|
build score matrix |
|
|
build score matrix |
|
|
""" |
|
|
""" |
|
|
assert self.num_process >= 2 and multiprocessing.cpu_count() >= 2 |
|
|
assert self.num_process >= 2 and multiprocessing.cpu_count() >= 2 |
|
|
print(f'Building score matrix from examples ...') |
|
|
|
|
|
|
|
|
print('Building score matrix from examples ...') |
|
|
results = [] |
|
|
results = [] |
|
|
num = len(examples) |
|
|
num = len(examples) |
|
|
sub_num, res_num = num // self.num_process, num % self.num_process |
|
|
sub_num, res_num = num // self.num_process, num % self.num_process |
|
|
@@ -512,7 +511,7 @@ class BPETextField(object): |
|
|
score_matrix, |
|
|
score_matrix, |
|
|
1.) # in case of empty label of self, resulting in score 0. |
|
|
1.) # in case of empty label of self, resulting in score 0. |
|
|
|
|
|
|
|
|
print(f'Built score matrix') |
|
|
|
|
|
|
|
|
print('Built score matrix') |
|
|
return score_matrix |
|
|
return score_matrix |
|
|
|
|
|
|
|
|
def extract_span_texts(self, text, label): |
|
|
def extract_span_texts(self, text, label): |
|
|
@@ -556,7 +555,7 @@ class BPETextField(object): |
|
|
|
|
|
|
|
|
token_list = [ |
|
|
token_list = [ |
|
|
tok for tok in map(str.strip, |
|
|
tok for tok in map(str.strip, |
|
|
re.split('(\W+)', text.lower())) |
|
|
|
|
|
|
|
|
re.split('(\\W+)', text.lower())) |
|
|
if len(tok) > 0 |
|
|
if len(tok) > 0 |
|
|
] |
|
|
] |
|
|
span_list = np.zeros(len(token_list), dtype=np.int32) |
|
|
span_list = np.zeros(len(token_list), dtype=np.int32) |
|
|
@@ -586,10 +585,10 @@ class BPETextField(object): |
|
|
history_span_mask.append(span_mask) |
|
|
history_span_mask.append(span_mask) |
|
|
history_label.append(self.fix_label(label)) |
|
|
history_label.append(self.fix_label(label)) |
|
|
|
|
|
|
|
|
|
|
|
tmp = self.utts_filter_pred(history[:-1]) and all( |
|
|
|
|
|
map(self.utt_filter_pred, history)) |
|
|
if ( |
|
|
if ( |
|
|
(self.utts_filter_pred(history[:-1]) |
|
|
|
|
|
and all(map(self.utt_filter_pred, history))) |
|
|
|
|
|
or data_type == 'test' |
|
|
|
|
|
|
|
|
tmp or data_type == 'test' |
|
|
) and role in self.trigger_role and t: # TODO consider test |
|
|
) and role in self.trigger_role and t: # TODO consider test |
|
|
src = [ |
|
|
src = [ |
|
|
s[-self.max_utt_len:] |
|
|
s[-self.max_utt_len:] |
|
|
@@ -603,11 +602,18 @@ class BPETextField(object): |
|
|
role |
|
|
role |
|
|
for role in history_role[:-1][-self.max_ctx_turn:] |
|
|
for role in history_role[:-1][-self.max_ctx_turn:] |
|
|
] |
|
|
] |
|
|
src = [[self.sos_u_id] + self.numericalize(s) + |
|
|
|
|
|
[self.eos_u_id] |
|
|
|
|
|
if roles[i] == 'user' else [self.sos_r_id] + |
|
|
|
|
|
self.numericalize(s) + [self.eos_r_id] |
|
|
|
|
|
for i, s in enumerate(src)] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
new_src = [] |
|
|
|
|
|
for i, s in enumerate(src): |
|
|
|
|
|
if roles[i] == 'user': |
|
|
|
|
|
user_or_sys = [self.eos_u_id] |
|
|
|
|
|
else: |
|
|
|
|
|
user_or_sys = [self.sos_r_id] |
|
|
|
|
|
tmp = [self.sos_u_id |
|
|
|
|
|
] + self.numericalize(s) + user_or_sys |
|
|
|
|
|
tmp = tmp + self.numericalize(s) + [self.eos_r_id] |
|
|
|
|
|
new_src.append(tmp) |
|
|
|
|
|
|
|
|
src_span_mask = [[0] + list(map(int, s)) + [0] |
|
|
src_span_mask = [[0] + list(map(int, s)) + [0] |
|
|
for s in src_span_mask] |
|
|
for s in src_span_mask] |
|
|
|
|
|
|
|
|
@@ -619,7 +625,7 @@ class BPETextField(object): |
|
|
ex = { |
|
|
ex = { |
|
|
'dialog_id': dialog_id, |
|
|
'dialog_id': dialog_id, |
|
|
'turn_id': turn['turn_id'], |
|
|
'turn_id': turn['turn_id'], |
|
|
'src': src, |
|
|
|
|
|
|
|
|
'src': new_src, |
|
|
'src_span_mask': src_span_mask, |
|
|
'src_span_mask': src_span_mask, |
|
|
'tgt': tgt, |
|
|
'tgt': tgt, |
|
|
'query_label': history_label[-2], |
|
|
'query_label': history_label[-2], |
|
|
@@ -654,7 +660,7 @@ class BPETextField(object): |
|
|
history, history_role, history_span_mask = [], [], [] |
|
|
history, history_role, history_span_mask = [], [], [] |
|
|
utterance, span_mask = [], [] |
|
|
utterance, span_mask = [], [] |
|
|
token_list = [ |
|
|
token_list = [ |
|
|
tok for tok in map(str.strip, re.split('(\W+)', text.lower())) |
|
|
|
|
|
|
|
|
tok for tok in map(str.strip, re.split('(\\W+)', text.lower())) |
|
|
if len(tok) > 0 |
|
|
if len(tok) > 0 |
|
|
] |
|
|
] |
|
|
span_list = np.zeros(len(token_list), dtype=np.int32) |
|
|
span_list = np.zeros(len(token_list), dtype=np.int32) |
|
|
@@ -680,10 +686,17 @@ class BPETextField(object): |
|
|
for s in history_span_mask[-self.max_ctx_turn:] |
|
|
for s in history_span_mask[-self.max_ctx_turn:] |
|
|
] |
|
|
] |
|
|
roles = [role for role in history_role[-self.max_ctx_turn:]] |
|
|
roles = [role for role in history_role[-self.max_ctx_turn:]] |
|
|
src = [[self.sos_u_id] + self.numericalize(s) + |
|
|
|
|
|
[self.eos_u_id] if roles[i] == 'user' else [self.sos_r_id] + |
|
|
|
|
|
self.numericalize(s) + [self.eos_r_id] |
|
|
|
|
|
for i, s in enumerate(src)] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
new_src = [] |
|
|
|
|
|
for i, s in enumerate(src): |
|
|
|
|
|
if roles[i] == 'user': |
|
|
|
|
|
user_or_sys = [self.eos_u_id] |
|
|
|
|
|
else: |
|
|
|
|
|
user_or_sys = [self.sos_r_id] |
|
|
|
|
|
tmp = [self.sos_u_id] + self.numericalize(s) + user_or_sys |
|
|
|
|
|
tmp = tmp + self.numericalize(s) + [self.eos_r_id] |
|
|
|
|
|
new_src.append(tmp) |
|
|
|
|
|
|
|
|
src_span_mask = [[0] + list(map(int, s)) + [0] |
|
|
src_span_mask = [[0] + list(map(int, s)) + [0] |
|
|
for s in src_span_mask] |
|
|
for s in src_span_mask] |
|
|
|
|
|
|
|
|
@@ -691,7 +704,7 @@ class BPETextField(object): |
|
|
'dialog_id': 'inference', |
|
|
'dialog_id': 'inference', |
|
|
'turn_id': 0, |
|
|
'turn_id': 0, |
|
|
'role': role, |
|
|
'role': role, |
|
|
'src': src, |
|
|
|
|
|
|
|
|
'src': new_src, |
|
|
'src_span_mask': src_span_mask, |
|
|
'src_span_mask': src_span_mask, |
|
|
'query_label': { |
|
|
'query_label': { |
|
|
'DEFAULT_DOMAIN': { |
|
|
'DEFAULT_DOMAIN': { |
|
|
@@ -734,7 +747,7 @@ class BPETextField(object): |
|
|
|
|
|
|
|
|
token_list = [ |
|
|
token_list = [ |
|
|
tok for tok in map(str.strip, |
|
|
tok for tok in map(str.strip, |
|
|
re.split('(\W+)', text.lower())) |
|
|
|
|
|
|
|
|
re.split('(\\W+)', text.lower())) |
|
|
if len(tok) > 0 |
|
|
if len(tok) > 0 |
|
|
] |
|
|
] |
|
|
span_list = np.zeros(len(token_list), dtype=np.int32) |
|
|
span_list = np.zeros(len(token_list), dtype=np.int32) |
|
|
@@ -763,10 +776,10 @@ class BPETextField(object): |
|
|
history_role.append(role) |
|
|
history_role.append(role) |
|
|
history_span_mask.append(span_mask) |
|
|
history_span_mask.append(span_mask) |
|
|
|
|
|
|
|
|
if ((self.utts_filter_pred(history) |
|
|
|
|
|
and all(map(self.utt_filter_pred, history))) |
|
|
|
|
|
or data_type == 'test' |
|
|
|
|
|
) and role in self.trigger_role: # TODO consider test |
|
|
|
|
|
|
|
|
tmp = self.utts_filter_pred(history) and all( |
|
|
|
|
|
map(self.utt_filter_pred, history)) |
|
|
|
|
|
tmp = tmp or data_type == 'test' |
|
|
|
|
|
if tmp and role in self.trigger_role: # TODO consider test |
|
|
src = [ |
|
|
src = [ |
|
|
s[-self.max_utt_len:] |
|
|
s[-self.max_utt_len:] |
|
|
for s in history[-self.max_ctx_turn:] |
|
|
for s in history[-self.max_ctx_turn:] |
|
|
@@ -778,11 +791,17 @@ class BPETextField(object): |
|
|
roles = [ |
|
|
roles = [ |
|
|
role for role in history_role[-self.max_ctx_turn:] |
|
|
role for role in history_role[-self.max_ctx_turn:] |
|
|
] |
|
|
] |
|
|
src = [[self.sos_u_id] + self.numericalize(s) + |
|
|
|
|
|
[self.eos_u_id] |
|
|
|
|
|
if roles[i] == 'user' else [self.sos_r_id] + |
|
|
|
|
|
self.numericalize(s) + [self.eos_r_id] |
|
|
|
|
|
for i, s in enumerate(src)] |
|
|
|
|
|
|
|
|
new_src = [] |
|
|
|
|
|
for i, s in enumerate(src): |
|
|
|
|
|
if roles[i] == 'user': |
|
|
|
|
|
user_or_sys = [self.eos_u_id] |
|
|
|
|
|
else: |
|
|
|
|
|
user_or_sys = [self.sos_r_id] |
|
|
|
|
|
tmp = [self.sos_u_id |
|
|
|
|
|
] + self.numericalize(s) + user_or_sys |
|
|
|
|
|
tmp = tmp + self.numericalize(s) + [self.eos_r_id] |
|
|
|
|
|
new_src.append(tmp) |
|
|
|
|
|
|
|
|
src_span_mask = [[0] + list(map(int, s)) + [0] |
|
|
src_span_mask = [[0] + list(map(int, s)) + [0] |
|
|
for s in src_span_mask] |
|
|
for s in src_span_mask] |
|
|
|
|
|
|
|
|
@@ -790,7 +809,7 @@ class BPETextField(object): |
|
|
'dialog_id': dialog_id, |
|
|
'dialog_id': dialog_id, |
|
|
'turn_id': turn['turn_id'], |
|
|
'turn_id': turn['turn_id'], |
|
|
'role': role, |
|
|
'role': role, |
|
|
'src': src, |
|
|
|
|
|
|
|
|
'src': new_src, |
|
|
'src_span_mask': src_span_mask, |
|
|
'src_span_mask': src_span_mask, |
|
|
'query_label': self.fix_label(label), |
|
|
'query_label': self.fix_label(label), |
|
|
'extra_info': turn.get('extra_info', '') |
|
|
'extra_info': turn.get('extra_info', '') |
|
|
@@ -829,7 +848,7 @@ class BPETextField(object): |
|
|
src_token.append(list(chain(*utts))[-self.max_len:]) |
|
|
src_token.append(list(chain(*utts))[-self.max_len:]) |
|
|
|
|
|
|
|
|
# Position ids |
|
|
# Position ids |
|
|
pos = [list(range(l)) for l in utt_lens] |
|
|
|
|
|
|
|
|
pos = [list(range(utt_len)) for utt_len in utt_lens] |
|
|
src_pos.append(list(chain(*pos))[-self.max_len:]) |
|
|
src_pos.append(list(chain(*pos))[-self.max_len:]) |
|
|
|
|
|
|
|
|
# Turn ids |
|
|
# Turn ids |
|
|
@@ -887,15 +906,15 @@ class BPETextField(object): |
|
|
understand = [self.understand_ids for _ in samples] |
|
|
understand = [self.understand_ids for _ in samples] |
|
|
understand_token = np.array(understand).astype('int64') |
|
|
understand_token = np.array(understand).astype('int64') |
|
|
batch['understand_token'] = understand_token |
|
|
batch['understand_token'] = understand_token |
|
|
batch['understand_mask'] = (understand_token != |
|
|
|
|
|
self.pad_id).astype('int64') |
|
|
|
|
|
|
|
|
batch['understand_mask'] = \ |
|
|
|
|
|
(understand_token != self.pad_id).astype('int64') |
|
|
|
|
|
|
|
|
if self.policy_ids and self.policy: |
|
|
if self.policy_ids and self.policy: |
|
|
policy = [self.policy_ids for _ in samples] |
|
|
policy = [self.policy_ids for _ in samples] |
|
|
policy_token = np.array(policy).astype('int64') |
|
|
policy_token = np.array(policy).astype('int64') |
|
|
batch['policy_token'] = policy_token |
|
|
batch['policy_token'] = policy_token |
|
|
batch['policy_mask'] = (policy_token != |
|
|
|
|
|
self.pad_id).astype('int64') |
|
|
|
|
|
|
|
|
batch['policy_mask'] = \ |
|
|
|
|
|
(policy_token != self.pad_id).astype('int64') |
|
|
|
|
|
|
|
|
if 'tgt' in samples[0]: |
|
|
if 'tgt' in samples[0]: |
|
|
tgt = [sp['tgt'] for sp in samples] |
|
|
tgt = [sp['tgt'] for sp in samples] |
|
|
@@ -952,8 +971,8 @@ class IntentBPETextField(BPETextField): |
|
|
|
|
|
|
|
|
# One example for each label |
|
|
# One example for each label |
|
|
example_inds = [] |
|
|
example_inds = [] |
|
|
for l in set(labels.tolist()): |
|
|
|
|
|
if l == -1: |
|
|
|
|
|
|
|
|
for lable in set(labels.tolist()): |
|
|
|
|
|
if lable == -1: |
|
|
continue |
|
|
continue |
|
|
|
|
|
|
|
|
ind = random.choice(cache[l]) |
|
|
ind = random.choice(cache[l]) |
|
|
@@ -1001,7 +1020,7 @@ class IntentBPETextField(BPETextField): |
|
|
src_token.append(list(chain(*utts))[-self.max_len:]) |
|
|
src_token.append(list(chain(*utts))[-self.max_len:]) |
|
|
|
|
|
|
|
|
# Position ids |
|
|
# Position ids |
|
|
pos = [list(range(l)) for l in utt_lens] |
|
|
|
|
|
|
|
|
pos = [list(range(utt_len)) for utt_len in utt_lens] |
|
|
src_pos.append(list(chain(*pos))[-self.max_len:]) |
|
|
src_pos.append(list(chain(*pos))[-self.max_len:]) |
|
|
|
|
|
|
|
|
# Turn ids |
|
|
# Turn ids |
|
|
|