From a52b75c9c113d55251b2aea088081e010226ff11 Mon Sep 17 00:00:00 2001 From: ly119399 Date: Wed, 29 Jun 2022 21:57:50 +0800 Subject: [PATCH] update dst conf --- modelscope/metainfo.py | 2 +- .../preprocessors/space/dst_processors.py | 36 +++++++++---------- .../nlp/test_dialog_state_tracking.py | 8 ++--- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 51618bb1..bd826141 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -54,7 +54,7 @@ class Pipelines(object): nli = 'nli' dialog_intent_prediction = 'dialog-intent-prediction' dialog_modeling = 'dialog-modeling' - dialog_state_tracking = 'dialog_state_tracking' + dialog_state_tracking = 'dialog-state-tracking' # audio tasks sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts' diff --git a/modelscope/preprocessors/space/dst_processors.py b/modelscope/preprocessors/space/dst_processors.py index ed20a168..6f9de25a 100644 --- a/modelscope/preprocessors/space/dst_processors.py +++ b/modelscope/preprocessors/space/dst_processors.py @@ -1135,24 +1135,24 @@ def convert_examples_to_features(examples, assert (len(input_ids) == len(input_ids_unmasked)) - if example_index < 10: - logger.info('*** Example ***') - logger.info('guid: %s' % (example.guid)) - logger.info('tokens: %s' % ' '.join(tokens)) - logger.info('input_ids: %s' % ' '.join([str(x) - for x in input_ids])) - logger.info('input_mask: %s' - % ' '.join([str(x) for x in input_mask])) - logger.info('segment_ids: %s' - % ' '.join([str(x) for x in segment_ids])) - logger.info('start_pos: %s' % str(start_pos_dict)) - logger.info('end_pos: %s' % str(end_pos_dict)) - logger.info('values: %s' % str(value_dict)) - logger.info('inform: %s' % str(inform_dict)) - logger.info('inform_slot: %s' % str(inform_slot_dict)) - logger.info('refer_id: %s' % str(refer_id_dict)) - logger.info('diag_state: %s' % str(diag_state_dict)) - logger.info('class_label_id: %s' % str(class_label_id_dict)) + # if example_index < 10: + # logger.info('*** Example ***') + # logger.info('guid: %s' % (example.guid)) + # logger.info('tokens: %s' % ' '.join(tokens)) + # logger.info('input_ids: %s' % ' '.join([str(x) + # for x in input_ids])) + # logger.info('input_mask: %s' + # % ' '.join([str(x) for x in input_mask])) + # logger.info('segment_ids: %s' + # % ' '.join([str(x) for x in segment_ids])) + # logger.info('start_pos: %s' % str(start_pos_dict)) + # logger.info('end_pos: %s' % str(end_pos_dict)) + # logger.info('values: %s' % str(value_dict)) + # logger.info('inform: %s' % str(inform_dict)) + # logger.info('inform_slot: %s' % str(inform_slot_dict)) + # logger.info('refer_id: %s' % str(refer_id_dict)) + # logger.info('diag_state: %s' % str(diag_state_dict)) + # logger.info('class_label_id: %s' % str(class_label_id_dict)) features.append( InputFeatures( diff --git a/tests/pipelines/nlp/test_dialog_state_tracking.py b/tests/pipelines/nlp/test_dialog_state_tracking.py index 89f1bafc..f8756cdb 100644 --- a/tests/pipelines/nlp/test_dialog_state_tracking.py +++ b/tests/pipelines/nlp/test_dialog_state_tracking.py @@ -46,10 +46,10 @@ class DialogStateTrackingTest(unittest.TestCase): pipelines = [ DialogStateTrackingPipeline( model=model, preprocessor=preprocessor), - # pipeline( - # task=Tasks.dialog_state_tracking, - # model=model, - # preprocessor=preprocessor) + pipeline( + task=Tasks.dialog_state_tracking, + model=model, + preprocessor=preprocessor) ] history_states = [{}]