| @@ -97,5 +97,6 @@ class SpaceForDialogStateTrackingModel(Model): | |||||
| 'input_ids_unmasked': input_ids_unmasked, | 'input_ids_unmasked': input_ids_unmasked, | ||||
| 'values': values, | 'values': values, | ||||
| 'inform': inform, | 'inform': inform, | ||||
| 'prefix': 'final' | |||||
| 'prefix': 'final', | |||||
| 'ds': input['ds'] | |||||
| } | } | ||||
| @@ -1,7 +1,6 @@ | |||||
| from .audio import LinearAECPipeline | |||||
| from .audio.ans_pipeline import ANSPipeline | |||||
| # from .audio import LinearAECPipeline | |||||
| # from .audio.ans_pipeline import ANSPipeline | |||||
| from .base import Pipeline | from .base import Pipeline | ||||
| from .builder import pipeline | from .builder import pipeline | ||||
| from .cv import * # noqa F403 | |||||
| from .multi_modal import * # noqa F403 | from .multi_modal import * # noqa F403 | ||||
| from .nlp import * # noqa F403 | from .nlp import * # noqa F403 | ||||
| @@ -45,7 +45,8 @@ class DialogStateTrackingPipeline(Pipeline): | |||||
| values = inputs['values'] | values = inputs['values'] | ||||
| inform = inputs['inform'] | inform = inputs['inform'] | ||||
| prefix = inputs['prefix'] | prefix = inputs['prefix'] | ||||
| ds = {slot: 'none' for slot in self.config.dst_slot_list} | |||||
| # ds = {slot: 'none' for slot in self.config.dst_slot_list} | |||||
| ds = inputs['ds'] | |||||
| ds = predict_and_format(self.config, self.tokenizer, _inputs, | ds = predict_and_format(self.config, self.tokenizer, _inputs, | ||||
| _outputs[2], _outputs[3], _outputs[4], | _outputs[2], _outputs[3], _outputs[4], | ||||
| @@ -113,7 +114,11 @@ def predict_and_format(config, tokenizer, features, per_slot_class_logits, | |||||
| 'false'): | 'false'): | ||||
| dialog_state[slot] = 'false' | dialog_state[slot] = 'false' | ||||
| elif class_prediction == config.dst_class_types.index('inform'): | elif class_prediction == config.dst_class_types.index('inform'): | ||||
| dialog_state[slot] = '§§' + inform[i][slot] | |||||
| # dialog_state[slot] = '§§' + inform[i][slot] | |||||
| if isinstance(inform[i][slot], str): | |||||
| dialog_state[slot] = inform[i][slot] | |||||
| elif isinstance(inform[i][slot], list): | |||||
| dialog_state[slot] = inform[i][slot][0] | |||||
| # Referral case is handled below | # Referral case is handled below | ||||
| prediction_addendum['slot_prediction_%s' | prediction_addendum['slot_prediction_%s' | ||||
| @@ -114,6 +114,44 @@ TASK_OUTPUTS = { | |||||
| # "scores": [0.9, 0.1, 0.05, 0.05] | # "scores": [0.9, 0.1, 0.05, 0.05] | ||||
| # } | # } | ||||
| Tasks.nli: ['scores', 'labels'], | Tasks.nli: ['scores', 'labels'], | ||||
| Tasks.dialog_modeling: [], | |||||
| Tasks.dialog_intent_prediction: [], | |||||
| # { | |||||
| # "dialog_states": { | |||||
| # "taxi-leaveAt": "none", | |||||
| # "taxi-destination": "none", | |||||
| # "taxi-departure": "none", | |||||
| # "taxi-arriveBy": "none", | |||||
| # "restaurant-book_people": "none", | |||||
| # "restaurant-book_day": "none", | |||||
| # "restaurant-book_time": "none", | |||||
| # "restaurant-food": "none", | |||||
| # "restaurant-pricerange": "none", | |||||
| # "restaurant-name": "none", | |||||
| # "restaurant-area": "none", | |||||
| # "hotel-book_people": "none", | |||||
| # "hotel-book_day": "none", | |||||
| # "hotel-book_stay": "none", | |||||
| # "hotel-name": "none", | |||||
| # "hotel-area": "none", | |||||
| # "hotel-parking": "none", | |||||
| # "hotel-pricerange": "cheap", | |||||
| # "hotel-stars": "none", | |||||
| # "hotel-internet": "none", | |||||
| # "hotel-type": "true", | |||||
| # "attraction-type": "none", | |||||
| # "attraction-name": "none", | |||||
| # "attraction-area": "none", | |||||
| # "train-book_people": "none", | |||||
| # "train-leaveAt": "none", | |||||
| # "train-destination": "none", | |||||
| # "train-day": "none", | |||||
| # "train-arriveBy": "none", | |||||
| # "train-departure": "none" | |||||
| # } | |||||
| # } | |||||
| Tasks.dialog_state_tracking: ['dialog_states'], | |||||
| # ============ audio tasks =================== | # ============ audio tasks =================== | ||||
| @@ -153,43 +191,5 @@ TASK_OUTPUTS = { | |||||
| # { | # { | ||||
| # "image": np.ndarray with shape [height, width, 3] | # "image": np.ndarray with shape [height, width, 3] | ||||
| # } | # } | ||||
| Tasks.text_to_image_synthesis: ['image'], | |||||
| Tasks.dialog_modeling: [], | |||||
| Tasks.dialog_intent_prediction: [], | |||||
| # { | |||||
| # "dialog_states": { | |||||
| # "taxi-leaveAt": "none", | |||||
| # "taxi-destination": "none", | |||||
| # "taxi-departure": "none", | |||||
| # "taxi-arriveBy": "none", | |||||
| # "restaurant-book_people": "none", | |||||
| # "restaurant-book_day": "none", | |||||
| # "restaurant-book_time": "none", | |||||
| # "restaurant-food": "none", | |||||
| # "restaurant-pricerange": "none", | |||||
| # "restaurant-name": "none", | |||||
| # "restaurant-area": "none", | |||||
| # "hotel-book_people": "none", | |||||
| # "hotel-book_day": "none", | |||||
| # "hotel-book_stay": "none", | |||||
| # "hotel-name": "none", | |||||
| # "hotel-area": "none", | |||||
| # "hotel-parking": "none", | |||||
| # "hotel-pricerange": "cheap", | |||||
| # "hotel-stars": "none", | |||||
| # "hotel-internet": "none", | |||||
| # "hotel-type": "true", | |||||
| # "attraction-type": "none", | |||||
| # "attraction-name": "none", | |||||
| # "attraction-area": "none", | |||||
| # "train-book_people": "none", | |||||
| # "train-leaveAt": "none", | |||||
| # "train-destination": "none", | |||||
| # "train-day": "none", | |||||
| # "train-arriveBy": "none", | |||||
| # "train-departure": "none" | |||||
| # } | |||||
| # } | |||||
| Tasks.dialog_state_tracking: ['dialog_states'] | |||||
| Tasks.text_to_image_synthesis: ['image'] | |||||
| } | } | ||||
| @@ -118,8 +118,14 @@ class DialogStateTrackingPreprocessor(Preprocessor): | |||||
| for slot in self.config.dst_slot_list | for slot in self.config.dst_slot_list | ||||
| } | } | ||||
| if len(history_states) > 2: | |||||
| ds = history_states[-2] | |||||
| else: | |||||
| ds = {slot: 'none' for slot in self.config.dst_slot_list} | |||||
| return { | return { | ||||
| 'batch': dataset, | 'batch': dataset, | ||||
| 'features': features, | 'features': features, | ||||
| 'diag_state': diag_state | |||||
| 'diag_state': diag_state, | |||||
| 'ds': ds | |||||
| } | } | ||||
| @@ -432,6 +432,7 @@ class multiwoz22Processor(DSTProcessor): | |||||
| usr_sys_switch = True | usr_sys_switch = True | ||||
| turn_itr = 0 | turn_itr = 0 | ||||
| inform_dict = {slot: 'none' for slot in slot_list} | |||||
| for utt in utterances: | for utt in utterances: | ||||
| # Assert that system and user utterances alternate | # Assert that system and user utterances alternate | ||||
| is_sys_utt = utt['metadata'] != {} | is_sys_utt = utt['metadata'] != {} | ||||
| @@ -1501,7 +1502,7 @@ if __name__ == '__main__': | |||||
| } | } | ||||
| }, {}] | }, {}] | ||||
| example = processor.create_example(utter3, history_states3, set_type, | |||||
| example = processor.create_example(utter2, history_states2, set_type, | |||||
| slot_list, {}, append_history, | slot_list, {}, append_history, | ||||
| use_history_labels, swap_utterances, | use_history_labels, swap_utterances, | ||||
| label_value_repetitions, | label_value_repetitions, | ||||