| @@ -33,22 +33,22 @@ DIALOG_ACT = 'Dialog_Act' | |||||
| utter1 = { | utter1 = { | ||||
| 'User-1': | 'User-1': | ||||
| "I'd really like to take my client out to a nice restaurant that serves indian food." | |||||
| "I'd really like to take my client out to a nice restaurant that serves indian food." | |||||
| } | } | ||||
| history_states1 = [ | history_states1 = [ | ||||
| {}, | {}, | ||||
| ] | ] | ||||
| utter2 = { | utter2 = { | ||||
| 'User-1': | 'User-1': | ||||
| "I'd really like to take my client out to a nice restaurant that serves indian food.", | |||||
| "I'd really like to take my client out to a nice restaurant that serves indian food.", | |||||
| 'System-1': | 'System-1': | ||||
| 'I show many restaurants that serve Indian food in that price range. What area would you like to travel to?', | |||||
| 'I show many restaurants that serve Indian food in that price range. What area would you like to travel to?', | |||||
| 'Dialog_Act-1': { | 'Dialog_Act-1': { | ||||
| 'Restaurant-Inform': [['choice', 'many'], ['food', 'Indian'], | 'Restaurant-Inform': [['choice', 'many'], ['food', 'Indian'], | ||||
| ['pricerange', 'that price range']] | ['pricerange', 'that price range']] | ||||
| }, | }, | ||||
| 'User-2': | 'User-2': | ||||
| 'I am looking for an expensive indian restaurant in the area of centre.', | |||||
| 'I am looking for an expensive indian restaurant in the area of centre.', | |||||
| } | } | ||||
| history_states2 = [{}, { | history_states2 = [{}, { | ||||
| @@ -77,11 +77,11 @@ history_states2 = [{}, { | |||||
| 'reference': 'JXVKZ7KV' | 'reference': 'JXVKZ7KV' | ||||
| }], | }], | ||||
| 'day': | 'day': | ||||
| 'sunday', | |||||
| 'sunday', | |||||
| 'people': | 'people': | ||||
| '6', | |||||
| '6', | |||||
| 'stay': | 'stay': | ||||
| '4' | |||||
| '4' | |||||
| }, | }, | ||||
| 'semi': { | 'semi': { | ||||
| 'area': '', | 'area': '', | ||||
| @@ -144,17 +144,17 @@ history_states2 = [{}, { | |||||
| utter3 = { | utter3 = { | ||||
| 'User-1': | 'User-1': | ||||
| "I'd really like to take my client out to a nice restaurant that serves indian food.", | |||||
| "I'd really like to take my client out to a nice restaurant that serves indian food.", | |||||
| 'System-1': | 'System-1': | ||||
| 'I show many restaurants that serve Indian food in that price range. What area would you like to travel to?', | |||||
| 'I show many restaurants that serve Indian food in that price range. What area would you like to travel to?', | |||||
| 'Dialog_Act-1': { | 'Dialog_Act-1': { | ||||
| 'Restaurant-Inform': [['choice', 'many'], ['food', 'Indian'], | 'Restaurant-Inform': [['choice', 'many'], ['food', 'Indian'], | ||||
| ['pricerange', 'that price range']] | ['pricerange', 'that price range']] | ||||
| }, | }, | ||||
| 'User-2': | 'User-2': | ||||
| 'I am looking for an expensive indian restaurant in the area of centre.', | |||||
| 'I am looking for an expensive indian restaurant in the area of centre.', | |||||
| 'System-2': | 'System-2': | ||||
| 'Might I recommend Saffron Brasserie? That is an expensive Indian restaurant in the center of town. I can book a table for you, if you like.', | |||||
| 'Might I recommend Saffron Brasserie? That is an expensive Indian restaurant in the center of town. I can book a table for you, if you like.', | |||||
| 'Dialog_Act-2': { | 'Dialog_Act-2': { | ||||
| 'Restaurant-Recommend': [['area', 'center of town'], | 'Restaurant-Recommend': [['area', 'center of town'], | ||||
| ['food', 'Indian'], | ['food', 'Indian'], | ||||
| @@ -190,11 +190,11 @@ history_states3 = [{}, { | |||||
| 'reference': 'JXVKZ7KV' | 'reference': 'JXVKZ7KV' | ||||
| }], | }], | ||||
| 'day': | 'day': | ||||
| 'sunday', | |||||
| 'sunday', | |||||
| 'people': | 'people': | ||||
| '6', | |||||
| '6', | |||||
| 'stay': | 'stay': | ||||
| '4' | |||||
| '4' | |||||
| }, | }, | ||||
| 'semi': { | 'semi': { | ||||
| 'area': '', | 'area': '', | ||||
| @@ -254,99 +254,98 @@ history_states3 = [{}, { | |||||
| } | } | ||||
| } | } | ||||
| }, {}, { | }, {}, { | ||||
| 'attraction': { | |||||
| 'book': { | |||||
| 'booked': [] | |||||
| }, | |||||
| 'semi': { | |||||
| 'area': '', | |||||
| 'name': '', | |||||
| 'type': '' | |||||
| } | |||||
| }, | |||||
| 'hospital': { | |||||
| 'book': { | |||||
| 'booked': [] | |||||
| }, | |||||
| 'semi': { | |||||
| 'department': '' | |||||
| } | |||||
| }, | |||||
| 'hotel': { | |||||
| 'book': { | |||||
| 'booked': [{ | |||||
| 'name': 'alexander bed and breakfast', | |||||
| 'reference': 'JXVKZ7KV' | |||||
| }], | |||||
| 'day': | |||||
| 'sunday', | |||||
| 'people': | |||||
| '6', | |||||
| 'stay': | |||||
| '4' | |||||
| }, | |||||
| 'semi': { | |||||
| 'area': '', | |||||
| 'internet': 'yes', | |||||
| 'name': 'alexander bed and breakfast', | |||||
| 'parking': 'yes', | |||||
| 'pricerange': 'cheap', | |||||
| 'stars': '', | |||||
| 'type': 'guesthouse' | |||||
| } | |||||
| }, | |||||
| 'police': { | |||||
| 'book': { | |||||
| 'booked': [] | |||||
| }, | |||||
| 'semi': {} | |||||
| }, | |||||
| 'restaurant': { | |||||
| 'book': { | |||||
| 'booked': [{ | |||||
| 'name': 'ask', | |||||
| 'reference': 'Y2Y8QYBY' | |||||
| }], | |||||
| 'day': 'sunday', | |||||
| 'people': '6', | |||||
| 'time': '18:45' | |||||
| }, | |||||
| 'semi': { | |||||
| 'area': 'centre', | |||||
| 'food': 'italian', | |||||
| 'name': 'ask', | |||||
| 'pricerange': 'cheap' | |||||
| } | |||||
| }, | |||||
| 'taxi': { | |||||
| 'book': { | |||||
| 'booked': [] | |||||
| }, | |||||
| 'semi': { | |||||
| 'arriveBy': '', | |||||
| 'departure': '', | |||||
| 'destination': '', | |||||
| 'leaveAt': '' | |||||
| } | |||||
| }, | |||||
| 'train': { | |||||
| 'book': { | |||||
| 'booked': [], | |||||
| 'people': '' | |||||
| }, | |||||
| 'semi': { | |||||
| 'arriveBy': '', | |||||
| 'day': '', | |||||
| 'departure': '', | |||||
| 'destination': '', | |||||
| 'leaveAt': '' | |||||
| } | |||||
| } | |||||
| }, {}] | |||||
| 'attraction': { | |||||
| 'book': { | |||||
| 'booked': [] | |||||
| }, | |||||
| 'semi': { | |||||
| 'area': '', | |||||
| 'name': '', | |||||
| 'type': '' | |||||
| } | |||||
| }, | |||||
| 'hospital': { | |||||
| 'book': { | |||||
| 'booked': [] | |||||
| }, | |||||
| 'semi': { | |||||
| 'department': '' | |||||
| } | |||||
| }, | |||||
| 'hotel': { | |||||
| 'book': { | |||||
| 'booked': [{ | |||||
| 'name': 'alexander bed and breakfast', | |||||
| 'reference': 'JXVKZ7KV' | |||||
| }], | |||||
| 'day': | |||||
| 'sunday', | |||||
| 'people': | |||||
| '6', | |||||
| 'stay': | |||||
| '4' | |||||
| }, | |||||
| 'semi': { | |||||
| 'area': '', | |||||
| 'internet': 'yes', | |||||
| 'name': 'alexander bed and breakfast', | |||||
| 'parking': 'yes', | |||||
| 'pricerange': 'cheap', | |||||
| 'stars': '', | |||||
| 'type': 'guesthouse' | |||||
| } | |||||
| }, | |||||
| 'police': { | |||||
| 'book': { | |||||
| 'booked': [] | |||||
| }, | |||||
| 'semi': {} | |||||
| }, | |||||
| 'restaurant': { | |||||
| 'book': { | |||||
| 'booked': [{ | |||||
| 'name': 'ask', | |||||
| 'reference': 'Y2Y8QYBY' | |||||
| }], | |||||
| 'day': 'sunday', | |||||
| 'people': '6', | |||||
| 'time': '18:45' | |||||
| }, | |||||
| 'semi': { | |||||
| 'area': 'centre', | |||||
| 'food': 'italian', | |||||
| 'name': 'ask', | |||||
| 'pricerange': 'cheap' | |||||
| } | |||||
| }, | |||||
| 'taxi': { | |||||
| 'book': { | |||||
| 'booked': [] | |||||
| }, | |||||
| 'semi': { | |||||
| 'arriveBy': '', | |||||
| 'departure': '', | |||||
| 'destination': '', | |||||
| 'leaveAt': '' | |||||
| } | |||||
| }, | |||||
| 'train': { | |||||
| 'book': { | |||||
| 'booked': [], | |||||
| 'people': '' | |||||
| }, | |||||
| 'semi': { | |||||
| 'arriveBy': '', | |||||
| 'day': '', | |||||
| 'departure': '', | |||||
| 'destination': '', | |||||
| 'leaveAt': '' | |||||
| } | |||||
| } | |||||
| }, {}] | |||||
| class DSTProcessor(object): | class DSTProcessor(object): | ||||
| ACTS_DICT = { | ACTS_DICT = { | ||||
| 'taxi-depart': 'taxi-departure', | 'taxi-depart': 'taxi-departure', | ||||
| 'taxi-dest': 'taxi-destination', | 'taxi-dest': 'taxi-destination', | ||||
| @@ -428,7 +427,7 @@ class DSTProcessor(object): | |||||
| for a in item: | for a in item: | ||||
| aa = a.lower().split('-') | aa = a.lower().split('-') | ||||
| if aa[1] == 'inform' or aa[1] == 'recommend' or aa[ | if aa[1] == 'inform' or aa[1] == 'recommend' or aa[ | ||||
| 1] == 'select' or aa[1] == 'book': | |||||
| 1] == 'select' or aa[1] == 'book': | |||||
| for i in item[a]: | for i in item[a]: | ||||
| s = i[0].lower() | s = i[0].lower() | ||||
| v = i[1].lower().strip() | v = i[1].lower().strip() | ||||
| @@ -443,7 +442,7 @@ class DSTProcessor(object): | |||||
| if key not in s_dict: | if key not in s_dict: | ||||
| s_dict[key] = list([v]) | s_dict[key] = list([v]) | ||||
| # ... Option 2: Keep last informed value | # ... Option 2: Keep last informed value | ||||
| #s_dict[key] = list([v]) | |||||
| # s_dict[key] = list([v]) | |||||
| return s_dict | return s_dict | ||||
| @@ -472,7 +471,7 @@ class multiwoz22Processor(DSTProcessor): | |||||
| '(\d{2})(:\d{2}) ?p\.?m\.?', lambda x: str( | '(\d{2})(:\d{2}) ?p\.?m\.?', lambda x: str( | ||||
| int(x.groups()[0]) + 12 | int(x.groups()[0]) + 12 | ||||
| if int(x.groups()[0]) < 12 else int(x.groups()[0])) + x.groups( | if int(x.groups()[0]) < 12 else int(x.groups()[0])) + x.groups( | ||||
| )[1], text) | |||||
| )[1], text) | |||||
| text = re.sub('(^| )24:(\d{2})', r'\g<1>00:\2', | text = re.sub('(^| )24:(\d{2})', r'\g<1>00:\2', | ||||
| text) # Correct times that use 24 as hour | text) # Correct times that use 24 as hour | ||||
| return text | return text | ||||
| @@ -509,7 +508,7 @@ class multiwoz22Processor(DSTProcessor): | |||||
| for a in acts[d][t]['dialog_act']: | for a in acts[d][t]['dialog_act']: | ||||
| aa = a.lower().split('-') | aa = a.lower().split('-') | ||||
| if aa[1] == 'inform' or aa[1] == 'recommend' or aa[ | if aa[1] == 'inform' or aa[1] == 'recommend' or aa[ | ||||
| 1] == 'select' or aa[1] == 'book': | |||||
| 1] == 'select' or aa[1] == 'book': | |||||
| for i in acts[d][t]['dialog_act'][a]: | for i in acts[d][t]['dialog_act'][a]: | ||||
| s = i[0].lower() | s = i[0].lower() | ||||
| v = i[1].lower().strip() | v = i[1].lower().strip() | ||||
| @@ -524,7 +523,7 @@ class multiwoz22Processor(DSTProcessor): | |||||
| if key not in s_dict: | if key not in s_dict: | ||||
| s_dict[key] = list([v]) | s_dict[key] = list([v]) | ||||
| # ... Option 2: Keep last informed value | # ... Option 2: Keep last informed value | ||||
| #s_dict[key] = list([v]) | |||||
| # s_dict[key] = list([v]) | |||||
| return s_dict | return s_dict | ||||
| # This should only contain label normalizations. All other mappings should | # This should only contain label normalizations. All other mappings should | ||||
| @@ -764,7 +763,7 @@ class multiwoz22Processor(DSTProcessor): | |||||
| inform_dict = {slot: 'none' for slot in slot_list} | inform_dict = {slot: 'none' for slot in slot_list} | ||||
| for slot in slot_list: | for slot in slot_list: | ||||
| if (str(dialog_id), str(turn_itr), | if (str(dialog_id), str(turn_itr), | ||||
| slot) in sys_inform_dict: | |||||
| slot) in sys_inform_dict: | |||||
| inform_dict[slot] = sys_inform_dict[(str(dialog_id), | inform_dict[slot] = sys_inform_dict[(str(dialog_id), | ||||
| str(turn_itr), | str(turn_itr), | ||||
| slot)] | slot)] | ||||
| @@ -802,7 +801,7 @@ class multiwoz22Processor(DSTProcessor): | |||||
| value_label = booked_slots[s] | value_label = booked_slots[s] | ||||
| # Remember modified slots and entire dialog state | # Remember modified slots and entire dialog state | ||||
| if cs in slot_list and cumulative_labels[ | if cs in slot_list and cumulative_labels[ | ||||
| cs] != value_label: | |||||
| cs] != value_label: | |||||
| modified_slots[cs] = value_label | modified_slots[cs] = value_label | ||||
| cumulative_labels[cs] = value_label | cumulative_labels[cs] = value_label | ||||
| @@ -884,13 +883,13 @@ class multiwoz22Processor(DSTProcessor): | |||||
| (informed_value, referred_slot, usr_utt_tok_label, | (informed_value, referred_slot, usr_utt_tok_label, | ||||
| class_type) = self.get_turn_label( | class_type) = self.get_turn_label( | ||||
| value_label, | |||||
| inform_label, | |||||
| sys_utt_tok, | |||||
| usr_utt_tok, | |||||
| slot, | |||||
| diag_seen_slots_value_dict, | |||||
| slot_last_occurrence=True) | |||||
| value_label, | |||||
| inform_label, | |||||
| sys_utt_tok, | |||||
| usr_utt_tok, | |||||
| slot, | |||||
| diag_seen_slots_value_dict, | |||||
| slot_last_occurrence=True) | |||||
| inform_dict[slot] = informed_value | inform_dict[slot] = informed_value | ||||
| @@ -903,7 +902,7 @@ class multiwoz22Processor(DSTProcessor): | |||||
| if label_value_repetitions and slot in diag_seen_slots_dict: | if label_value_repetitions and slot in diag_seen_slots_dict: | ||||
| if class_type == 'copy_value' and list( | if class_type == 'copy_value' and list( | ||||
| diag_seen_slots_value_dict.values()).count( | diag_seen_slots_value_dict.values()).count( | ||||
| value_label) > 1: | |||||
| value_label) > 1: | |||||
| class_type = 'none' | class_type = 'none' | ||||
| usr_utt_tok_label = [0 for _ in usr_utt_tok_label] | usr_utt_tok_label = [0 for _ in usr_utt_tok_label] | ||||
| @@ -915,15 +914,15 @@ class multiwoz22Processor(DSTProcessor): | |||||
| if swap_utterances: | if swap_utterances: | ||||
| new_hst_utt_tok_label_dict[ | new_hst_utt_tok_label_dict[ | ||||
| slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[ | slot] = usr_utt_tok_label + sys_utt_tok_label + new_hst_utt_tok_label_dict[ | ||||
| slot] | |||||
| slot] | |||||
| else: | else: | ||||
| new_hst_utt_tok_label_dict[ | new_hst_utt_tok_label_dict[ | ||||
| slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[ | slot] = sys_utt_tok_label + usr_utt_tok_label + new_hst_utt_tok_label_dict[ | ||||
| slot] | |||||
| slot] | |||||
| else: | else: | ||||
| new_hst_utt_tok_label_dict[slot] = [ | new_hst_utt_tok_label_dict[slot] = [ | ||||
| 0 for _ in sys_utt_tok_label + usr_utt_tok_label | 0 for _ in sys_utt_tok_label + usr_utt_tok_label | ||||
| + new_hst_utt_tok_label_dict[slot] | |||||
| + new_hst_utt_tok_label_dict[slot] | |||||
| ] | ] | ||||
| # For now, we map all occurences of unpointable slot values | # For now, we map all occurences of unpointable slot values | ||||
| @@ -936,10 +935,10 @@ class multiwoz22Processor(DSTProcessor): | |||||
| referral_dict[slot] = 'none' | referral_dict[slot] = 'none' | ||||
| if analyze: | if analyze: | ||||
| if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[ | if slot not in diag_seen_slots_dict or value_label != diag_seen_slots_value_dict[ | ||||
| slot]: | |||||
| slot]: | |||||
| print('(%s): %s, ' % (slot, value_label), end='') | print('(%s): %s, ' % (slot, value_label), end='') | ||||
| elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[ | elif slot in diag_seen_slots_dict and class_type == diag_seen_slots_dict[ | ||||
| slot] and class_type != 'copy_value' and class_type != 'inform': | |||||
| slot] and class_type != 'copy_value' and class_type != 'inform': | |||||
| # If slot has seen before and its class type did not change, label this slot a not present, | # If slot has seen before and its class type did not change, label this slot a not present, | ||||
| # assuming that the slot has not actually been mentioned in this turn. | # assuming that the slot has not actually been mentioned in this turn. | ||||
| # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform, | # Exceptions are copy_value and inform. If a seen slot has been tagged as copy_value or inform, | ||||
| @@ -1195,7 +1194,7 @@ def convert_examples_to_features(examples, | |||||
| if slot_value_dropout == 0.0 or joint_label == 0: | if slot_value_dropout == 0.0 or joint_label == 0: | ||||
| tokens.extend(sub_tokens) | tokens.extend(sub_tokens) | ||||
| else: | else: | ||||
| rn_list = np.random.random_sample((len(sub_tokens), )) | |||||
| rn_list = np.random.random_sample((len(sub_tokens),)) | |||||
| for rn, sub_token in zip(rn_list, sub_tokens): | for rn, sub_token in zip(rn_list, sub_tokens): | ||||
| if rn > slot_value_dropout: | if rn > slot_value_dropout: | ||||
| tokens.append(sub_token) | tokens.append(sub_token) | ||||
| @@ -1262,7 +1261,7 @@ def convert_examples_to_features(examples, | |||||
| def _get_start_end_pos(class_type, token_label_ids, max_seq_length): | def _get_start_end_pos(class_type, token_label_ids, max_seq_length): | ||||
| if class_type == 'copy_value' and 1 not in token_label_ids: | if class_type == 'copy_value' and 1 not in token_label_ids: | ||||
| #logger.warn("copy_value label, but token_label not detected. Setting label to 'none'.") | |||||
| # logger.warn("copy_value label, but token_label not detected. Setting label to 'none'.") | |||||
| class_type = 'none' | class_type = 'none' | ||||
| start_pos = 0 | start_pos = 0 | ||||
| end_pos = 0 | end_pos = 0 | ||||