
update

master · ly119399 · 3 years ago
parent commit 85081fb3b4

5 changed files with 114 additions and 90 deletions
1. modelscope/models/nlp/space/model/generator.py         +4   -7
2. modelscope/preprocessors/space/fields/gen_field.py     +10  -11
3. modelscope/preprocessors/space/fields/intent_field.py  +61  -42
4. modelscope/preprocessors/space/tokenizer.py            +19  -12
5. modelscope/utils/nlp/space/utils.py                    +20  -18

modelscope/models/nlp/space/model/generator.py  (+4, -7)

@@ -201,9 +201,6 @@ class BeamSearch(Generator):
         predictions = torch.cat([predictions, preds.unsqueeze(2)], dim=2)
         state = repeat(state, beam_size)
 
-        parent_idx_list = []
-        pred_list = []
-
         if max_gen_len is None:
             max_gen_len = self.max_gen_len
         for step in range(2, max_gen_len + 1):
@@ -229,11 +226,11 @@ class BeamSearch(Generator):
             pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \
                 (1 - torch.not_equal(pre_ids, self.pad_id).float())
 
-            scores = scores * (1 - pre_eos_mask) + \
-                pre_eos_mask.repeat(1, 1, self.vocab_size) * scores_after_end
+            scores = scores * (1 - pre_eos_mask) + pre_eos_mask.repeat(
+                1, 1, self.vocab_size) * scores_after_end
             if self.length_average:
-                scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 -
-                                                                    1 / step)
+                scaled_value = \
+                    pre_eos_mask + (1 - pre_eos_mask) * (1 - 1 / step)
                 sequence_scores = sequence_scores.unsqueeze(2) * scaled_value
                 scaled_value = pre_eos_mask + (1 - pre_eos_mask) * (1 / step)
                 scores = scores * scaled_value
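
Aside: the two statements rewrapped in the second hunk implement score freezing for finished beams. A minimal standalone sketch of the pattern with toy shapes and ids (pad_id, eos_id, and scores_after_end here are illustrative assumptions, not the model's real configuration):

    import torch

    pad_id, eos_id, vocab_size = 0, 2, 5
    pre_ids = torch.tensor([[[2], [3]]])      # [batch=1, beam=2, 1]
    scores = torch.rand(1, 2, vocab_size)     # next-token scores
    scores_after_end = torch.full((1, 2, vocab_size), -1e9)
    scores_after_end[:, :, pad_id] = 0.       # a finished beam may only emit pad

    # 1.0 where the previous token was eos or pad, else 0.0
    pre_eos_mask = (1 - torch.not_equal(pre_ids, eos_id).float()) + \
        (1 - torch.not_equal(pre_ids, pad_id).float())

    # keep live beams' scores; overwrite finished beams with scores_after_end
    scores = scores * (1 - pre_eos_mask) + pre_eos_mask.repeat(
        1, 1, vocab_size) * scores_after_end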


modelscope/preprocessors/space/fields/gen_field.py  (+10, -11)

@@ -171,7 +171,7 @@ class BPETextField(object):
             src_token.append(list(chain(*utts))[-self.max_len:])
 
             # Position ids
-            pos = [list(range(l)) for l in utt_lens]
+            pos = [list(range(utt_len)) for utt_len in utt_lens]
             src_pos.append(list(chain(*pos))[-self.max_len:])
 
             # Turn ids
@@ -205,15 +205,15 @@ class BPETextField(object):
             understand = [self.understand_ids for _ in samples]
             understand_token = np.array(understand).astype('int64')
             batch['understand_token'] = understand_token
-            batch['understand_mask'] = (understand_token !=
-                                        self.pad_id).astype('int64')
+            batch['understand_mask'] = \
+                (understand_token != self.pad_id).astype('int64')
 
         if self.policy_ids and self.policy:
             policy = [self.policy_ids for _ in samples]
             policy_token = np.array(policy).astype('int64')
             batch['policy_token'] = policy_token
-            batch['policy_mask'] = (policy_token !=
-                                    self.pad_id).astype('int64')
+            batch['policy_mask'] = \
+                (policy_token != self.pad_id).astype('int64')
 
         if 'tgt' in samples[0]:
             tgt = [sp['tgt'] for sp in samples]
@@ -421,7 +421,7 @@ class MultiWOZBPETextField(BPETextField):
             try:
                 log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % (
                     k, len(turn_bucket[k]), len(batches), len(batches[-1]))
-            except:
+            except Exception:
                 log_str += 'turn num:%d, dial num: %d, batch num: %d last batch len: %d\n' % (
                     k, len(turn_bucket[k]), len(batches), 0.0)
         # print("turn num:%d, dial num:v%d, batch num: %d, "%(k, len(turn_bucket[k]), len(batches)))
@@ -520,7 +520,7 @@ class MultiWOZBPETextField(BPETextField):
                         ns) is not str else ns
                     if ns == "'s":
                         continue
-                except:
+                except Exception:
                     continue
                 if not constraint_dict.get(domain):
                     constraint_dict[domain] = {}
@@ -670,10 +670,9 @@ class MultiWOZBPETextField(BPETextField):
                     pv_turn['aspn'] + pv_turn['resp']
                 ]
             else:
-                pv_context = pv_turn['labels'] + [
-                    pv_turn['bspn'] + pv_turn['db'] + pv_turn['aspn'] +
-                    pv_turn['resp']
-                ]
+                pv_info = pv_turn['bspn'] + pv_turn['db'] + pv_turn[
+                    'aspn'] + pv_turn['resp']
+                pv_context = pv_turn['labels'] + [pv_info]
 
             # prompt response, add sos_r
             inputs['src'] = pv_context + [context]
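
Aside: the mask idiom rewrapped in the second hunk, in isolation (toy token ids; pad_id = 0 is an assumption):

    import numpy as np

    pad_id = 0
    understand_token = np.array([[5, 7, 0], [3, 0, 0]]).astype('int64')
    understand_mask = (understand_token != pad_id).astype('int64')
    print(understand_mask)  # [[1 1 0]
                            #  [1 0 0]]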


modelscope/preprocessors/space/fields/intent_field.py  (+61, -42)

@@ -229,9 +229,8 @@ class BPETextField(object):
 
                 # 10% randomly change token to random token
                 elif prob < 0.9:
-                    output_chars.append(
-                        random.randint(1, self.vocab_size -
-                                       1))  # start from 1, to exclude pad_id
+                    tmp = random.randint(1, self.vocab_size - 1)
+                    output_chars.append(tmp)  # start from 1, to exclude pad_id
 
                 # 10% randomly change token to current token
                 else:
@@ -401,7 +400,7 @@ class BPETextField(object):
         build symmetric score matrix
         """
         assert self.num_process == 1
-        print(f'Building score matrix from examples ...')
+        print('Building score matrix from examples ...')
         num = len(examples)
         score_matrix = np.eye(
             num, num, dtype='float32'
@@ -415,7 +414,7 @@ class BPETextField(object):
                 score_matrix[i][j] = score
                 score_matrix[j][i] = score
 
-        print(f'Built score matrix')
+        print('Built score matrix')
         return score_matrix
 
     def build_score_matrix_on_the_fly(self,
@@ -482,7 +481,7 @@ class BPETextField(object):
         build score matrix
         """
         assert self.num_process >= 2 and multiprocessing.cpu_count() >= 2
-        print(f'Building score matrix from examples ...')
+        print('Building score matrix from examples ...')
         results = []
         num = len(examples)
         sub_num, res_num = num // self.num_process, num % self.num_process
@@ -512,7 +511,7 @@ class BPETextField(object):
             score_matrix,
             1.)  # in case of empty label of self, resulting in score 0.
 
-        print(f'Built score matrix')
+        print('Built score matrix')
         return score_matrix
 
     def extract_span_texts(self, text, label):
@@ -556,7 +555,7 @@ class BPETextField(object):
 
             token_list = [
                 tok for tok in map(str.strip,
-                                   re.split('(\W+)', text.lower()))
+                                   re.split('(\\W+)', text.lower()))
                 if len(tok) > 0
             ]
             span_list = np.zeros(len(token_list), dtype=np.int32)
@@ -586,10 +585,10 @@ class BPETextField(object):
                 history_span_mask.append(span_mask)
                 history_label.append(self.fix_label(label))
 
+                tmp = self.utts_filter_pred(history[:-1]) and all(
+                    map(self.utt_filter_pred, history))
                 if (
-                    (self.utts_filter_pred(history[:-1])
-                     and all(map(self.utt_filter_pred, history)))
-                    or data_type == 'test'
+                        tmp or data_type == 'test'
                 ) and role in self.trigger_role and t:  # TODO consider test
                     src = [
                         s[-self.max_utt_len:]
@@ -603,11 +602,18 @@ class BPETextField(object):
                         role
                         for role in history_role[:-1][-self.max_ctx_turn:]
                     ]
-                    src = [[self.sos_u_id] + self.numericalize(s) +
-                           [self.eos_u_id]
-                           if roles[i] == 'user' else [self.sos_r_id] +
-                           self.numericalize(s) + [self.eos_r_id]
-                           for i, s in enumerate(src)]
+
+                    new_src = []
+                    for i, s in enumerate(src):
+                        if roles[i] == 'user':
+                            tmp = [self.sos_u_id] + self.numericalize(s) \
+                                + [self.eos_u_id]
+                        else:
+                            tmp = [self.sos_r_id] + self.numericalize(s) \
+                                + [self.eos_r_id]
+                        new_src.append(tmp)
 
                     src_span_mask = [[0] + list(map(int, s)) + [0]
                                      for s in src_span_mask]

@@ -619,7 +625,7 @@ class BPETextField(object):
                     ex = {
                         'dialog_id': dialog_id,
                         'turn_id': turn['turn_id'],
-                        'src': src,
+                        'src': new_src,
                         'src_span_mask': src_span_mask,
                         'tgt': tgt,
                         'query_label': history_label[-2],
@@ -654,7 +660,7 @@ class BPETextField(object):
         history, history_role, history_span_mask = [], [], []
         utterance, span_mask = [], []
         token_list = [
-            tok for tok in map(str.strip, re.split('(\W+)', text.lower()))
+            tok for tok in map(str.strip, re.split('(\\W+)', text.lower()))
             if len(tok) > 0
         ]
         span_list = np.zeros(len(token_list), dtype=np.int32)
@@ -680,10 +686,17 @@ class BPETextField(object):
             for s in history_span_mask[-self.max_ctx_turn:]
         ]
         roles = [role for role in history_role[-self.max_ctx_turn:]]
-        src = [[self.sos_u_id] + self.numericalize(s) +
-               [self.eos_u_id] if roles[i] == 'user' else [self.sos_r_id] +
-               self.numericalize(s) + [self.eos_r_id]
-               for i, s in enumerate(src)]
+
+        new_src = []
+        for i, s in enumerate(src):
+            if roles[i] == 'user':
+                tmp = [self.sos_u_id] + self.numericalize(s) + [self.eos_u_id]
+            else:
+                tmp = [self.sos_r_id] + self.numericalize(s) + [self.eos_r_id]
+            new_src.append(tmp)
 
         src_span_mask = [[0] + list(map(int, s)) + [0]
                          for s in src_span_mask]

@@ -691,7 +704,7 @@ class BPETextField(object):
             'dialog_id': 'inference',
             'turn_id': 0,
             'role': role,
-            'src': src,
+            'src': new_src,
             'src_span_mask': src_span_mask,
             'query_label': {
                 'DEFAULT_DOMAIN': {
@@ -734,7 +747,7 @@ class BPETextField(object):
 
             token_list = [
                 tok for tok in map(str.strip,
-                                   re.split('(\W+)', text.lower()))
+                                   re.split('(\\W+)', text.lower()))
                 if len(tok) > 0
             ]
             span_list = np.zeros(len(token_list), dtype=np.int32)
@@ -763,10 +776,10 @@ class BPETextField(object):
                 history_role.append(role)
                 history_span_mask.append(span_mask)
 
-                if ((self.utts_filter_pred(history)
-                     and all(map(self.utt_filter_pred, history)))
-                    or data_type == 'test'
-                    ) and role in self.trigger_role:  # TODO consider test
+                tmp = self.utts_filter_pred(history) and all(
+                    map(self.utt_filter_pred, history))
+                tmp = tmp or data_type == 'test'
+                if tmp and role in self.trigger_role:  # TODO consider test
                     src = [
                         s[-self.max_utt_len:]
                         for s in history[-self.max_ctx_turn:]
@@ -778,11 +791,17 @@ class BPETextField(object):
                     roles = [
                         role for role in history_role[-self.max_ctx_turn:]
                     ]
-                    src = [[self.sos_u_id] + self.numericalize(s) +
-                           [self.eos_u_id]
-                           if roles[i] == 'user' else [self.sos_r_id] +
-                           self.numericalize(s) + [self.eos_r_id]
-                           for i, s in enumerate(src)]
+                    new_src = []
+                    for i, s in enumerate(src):
+                        if roles[i] == 'user':
+                            tmp = [self.sos_u_id] + self.numericalize(s) \
+                                + [self.eos_u_id]
+                        else:
+                            tmp = [self.sos_r_id] + self.numericalize(s) \
+                                + [self.eos_r_id]
+                        new_src.append(tmp)
 
                     src_span_mask = [[0] + list(map(int, s)) + [0]
                                      for s in src_span_mask]

@@ -790,7 +809,7 @@ class BPETextField(object):
             'dialog_id': dialog_id,
             'turn_id': turn['turn_id'],
             'role': role,
-            'src': src,
+            'src': new_src,
             'src_span_mask': src_span_mask,
             'query_label': self.fix_label(label),
             'extra_info': turn.get('extra_info', '')
@@ -829,7 +848,7 @@ class BPETextField(object):
             src_token.append(list(chain(*utts))[-self.max_len:])
 
             # Position ids
-            pos = [list(range(l)) for l in utt_lens]
+            pos = [list(range(utt_len)) for utt_len in utt_lens]
             src_pos.append(list(chain(*pos))[-self.max_len:])
 
             # Turn ids
@@ -887,15 +906,15 @@ class BPETextField(object):
             understand = [self.understand_ids for _ in samples]
             understand_token = np.array(understand).astype('int64')
             batch['understand_token'] = understand_token
-            batch['understand_mask'] = (understand_token !=
-                                        self.pad_id).astype('int64')
+            batch['understand_mask'] = \
+                (understand_token != self.pad_id).astype('int64')
 
         if self.policy_ids and self.policy:
             policy = [self.policy_ids for _ in samples]
             policy_token = np.array(policy).astype('int64')
            batch['policy_token'] = policy_token
-            batch['policy_mask'] = (policy_token !=
-                                    self.pad_id).astype('int64')
+            batch['policy_mask'] = \
+                (policy_token != self.pad_id).astype('int64')
 
         if 'tgt' in samples[0]:
             tgt = [sp['tgt'] for sp in samples]
@@ -952,8 +971,8 @@ class IntentBPETextField(BPETextField):
 
         # One example for each label
         example_inds = []
-        for l in set(labels.tolist()):
-            if l == -1:
+        for label in set(labels.tolist()):
+            if label == -1:
                 continue
 
-            ind = random.choice(cache[l])
+            ind = random.choice(cache[label])
@@ -1001,7 +1020,7 @@ class IntentBPETextField(BPETextField):
             src_token.append(list(chain(*utts))[-self.max_len:])
 
             # Position ids
-            pos = [list(range(l)) for l in utt_lens]
+            pos = [list(range(utt_len)) for utt_len in utt_lens]
             src_pos.append(list(chain(*pos))[-self.max_len:])
 
             # Turn ids
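
Aside: the re.split('(\\W+)', ...) calls touched above split text on runs of non-word characters while keeping the separators as tokens, then drop whitespace-only pieces. A quick standalone illustration (the sample sentence is made up):

    import re

    text = "I'd like a cheap hotel, please!"
    token_list = [
        tok for tok in map(str.strip, re.split('(\\W+)', text.lower()))
        if len(tok) > 0
    ]
    print(token_list)
    # ['i', "'", 'd', 'like', 'a', 'cheap', 'hotel', ',', 'please', '!']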


modelscope/preprocessors/space/tokenizer.py  (+19, -12)

@@ -325,13 +325,15 @@ class BasicTokenizer(object):
         # as is Japanese Hiragana and Katakana. Those alphabets are used to write
         # space-separated words, so they are not treated specially and handled
         # like the all of the other languages.
-        if ((cp >= 0x4E00 and cp <= 0x9FFF) or (cp >= 0x3400 and cp <= 0x4DBF)
-                or (cp >= 0x20000 and cp <= 0x2A6DF)
-                or (cp >= 0x2A700 and cp <= 0x2B73F)
-                or (cp >= 0x2B740 and cp <= 0x2B81F)
-                or (cp >= 0x2B820 and cp <= 0x2CEAF)
-                or (cp >= 0xF900 and cp <= 0xFAFF)
-                or (cp >= 0x2F800 and cp <= 0x2FA1F)):
+        tmp = (cp >= 0x4E00 and cp <= 0x9FFF)
+        tmp = tmp or (cp >= 0x3400 and cp <= 0x4DBF)
+        tmp = tmp or (cp >= 0x20000 and cp <= 0x2A6DF)
+        tmp = tmp or (cp >= 0x2A700 and cp <= 0x2B73F)
+        tmp = tmp or (cp >= 0x2B740 and cp <= 0x2B81F)
+        tmp = tmp or (cp >= 0x2B820 and cp <= 0x2CEAF)
+        tmp = tmp or (cp >= 0xF900 and cp <= 0xFAFF)
+        tmp = tmp or (cp >= 0x2F800 and cp <= 0x2FA1F)
+        if tmp:
             return True
 
         return False
@@ -441,8 +443,11 @@ def _is_punctuation(char):
     # Characters such as "^", "$", and "`" are not in the Unicode
     # Punctuation class but we treat them as punctuation anyways, for
     # consistency.
-    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64)
-            or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
+    tmp = (cp >= 33 and cp <= 47)
+    tmp = tmp or (cp >= 58 and cp <= 64)
+    tmp = tmp or (cp >= 91 and cp <= 96)
+    tmp = tmp or (cp >= 123 and cp <= 126)
+    if tmp:
         return True
     cat = unicodedata.category(char)
     if cat.startswith('P'):
@@ -589,7 +594,7 @@ class GPT2Tokenizer(object):
                 j = word.index(first, i)
                 new_word.extend(word[i:j])
                 i = j
-            except:
+            except Exception:
                 new_word.extend(word[i:])
                 break
 
@@ -625,8 +630,10 @@ class GPT2Tokenizer(object):
     def convert_tokens_to_ids(self, tokens):
         """ Converts a sequence of tokens into ids using the vocab. """
         ids = []
-        if isinstance(tokens, str) or (sys.version_info[0] == 2
-                                       and isinstance(tokens, unicode)):
+        python_version_3 = isinstance(tokens, str)
+        python_version_2 = (
+            sys.version_info[0] == 2 and isinstance(tokens, unicode))
+        if python_version_3 or python_version_2:
             if tokens in self.special_tokens:
                 return self.special_tokens[tokens]
             else:
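
Aside: the chained tmp = tmp or ... rewrite keeps the formatter happy but is verbose. An equivalent table-driven form of the same CJK range check, shown only for comparison (not part of this commit):

    CJK_RANGES = [
        (0x4E00, 0x9FFF), (0x3400, 0x4DBF), (0x20000, 0x2A6DF),
        (0x2A700, 0x2B73F), (0x2B740, 0x2B81F), (0x2B820, 0x2CEAF),
        (0xF900, 0xFAFF), (0x2F800, 0x2FA1F),
    ]

    def is_cjk_codepoint(cp):
        # True if cp falls inside any CJK Unified Ideographs block.
        return any(lo <= cp <= hi for lo, hi in CJK_RANGES)

    assert is_cjk_codepoint(ord('中')) and not is_cjk_codepoint(ord('a'))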


modelscope/utils/nlp/space/utils.py  (+20, -18)

@@ -46,8 +46,8 @@ def clean_replace(s, r, t, forward=True, backward=False):
             return s, -1
 
         if forward:
-            while idx_r < len(s) and (s[idx_r].isalpha()
-                                      or s[idx_r].isdigit()):
+            while \
+                    idx_r < len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()):
                 idx_r += 1
         elif idx_r != len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()):
             return s, -1
@@ -122,13 +122,15 @@ class MultiWOZVocab(object):
             self._word2idx[word] = idx
 
     def construct(self):
-        l = sorted(self._freq_dict.keys(), key=lambda x: -self._freq_dict[x])
+        freq_dict_sorted = sorted(
+            self._freq_dict.keys(), key=lambda x: -self._freq_dict[x])
         print('Vocabulary size including oov: %d' %
-              (len(l) + len(self._idx2word)))
-        if len(l) + len(self._idx2word) < self.vocab_size:
+              (len(freq_dict_sorted) + len(self._idx2word)))
+        if len(freq_dict_sorted) + len(self._idx2word) < self.vocab_size:
             logging.warning(
                 'actual label set smaller than that configured: {}/{}'.format(
-                    len(l) + len(self._idx2word), self.vocab_size))
+                    len(freq_dict_sorted) + len(self._idx2word),
+                    self.vocab_size))
         for word in ontology.all_domains + ['general']:
             word = '[' + word + ']'
             self._add_to_vocab(word)
@@ -137,10 +139,10 @@ class MultiWOZVocab(object):
             self._add_to_vocab(word)
         for word in ontology.all_slots:
             self._add_to_vocab(word)
-        for word in l:
+        for word in freq_dict_sorted:
             if word.startswith('[value_') and word.endswith(']'):
                 self._add_to_vocab(word)
-        for word in l:
+        for word in freq_dict_sorted:
             self._add_to_vocab(word)
         self.vocab_size_oov = len(self._idx2word)
@@ -192,13 +194,13 @@ class MultiWOZVocab(object):
         else:
             return self._idx2word[idx] + '(o)'
 
-    def sentence_decode(self, index_list, eos=None, indicate_oov=False):
-        l = [self.decode(_, indicate_oov) for _ in index_list]
-        if not eos or eos not in l:
-            return ' '.join(l)
-        else:
-            idx = l.index(eos)
-            return ' '.join(l[:idx])
-
-    def nl_decode(self, l, eos=None):
-        return [self.sentence_decode(_, eos) + '\n' for _ in l]
+    # def sentence_decode(self, index_list, eos=None, indicate_oov=False):
+    #     l = [self.decode(_, indicate_oov) for _ in index_list]
+    #     if not eos or eos not in l:
+    #         return ' '.join(l)
+    #     else:
+    #         idx = l.index(eos)
+    #         return ' '.join(l[:idx])
+    #
+    # def nl_decode(self, l, eos=None):
+    #     return [self.sentence_decode(_, eos) + '\n' for _ in l]
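
Aside: for the record, the behavior of the now commented-out sentence_decode, restated as a free function (names and tokens here are illustrative only):

    def sentence_decode(words, eos=None):
        # Join decoded tokens, truncating at the first eos token if present.
        if not eos or eos not in words:
            return ' '.join(words)
        return ' '.join(words[:words.index(eos)])

    print(sentence_decode(['hello', 'world', '<eos>', '<pad>'], eos='<eos>'))
    # -> 'hello world'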
