Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10763022
fix sequence labeling postprocess bug (target branch: master)
@@ -92,6 +92,8 @@ class NamedEntityRecognitionPipeline(Pipeline):
         offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']]
         labels = [self.id2label[x] for x in predictions]
+        if len(labels) > len(offset_mapping):
+            labels = labels[1:-1]
         chunks = []
         chunk = {}
         for label, offsets in zip(labels, offset_mapping):
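The guard added here trims the two extra labels that appear when the model output still carries predictions for the special tokens ([CLS]/[SEP] in BERT-style tokenizers) while the offset mapping only covers the raw text. A minimal sketch of the mismatch, with made-up labels and offsets (not values from this PR):

    # One label per model token, including [CLS] and [SEP], but offsets
    # only for the real characters of the input text.
    labels = ['O', 'B-PER', 'E-PER', 'O']   # [CLS], 王, 力, [SEP]
    offset_mapping = [[0, 1], [1, 2]]       # offsets for 王 and 力 only

    if len(labels) > len(offset_mapping):
        labels = labels[1:-1]               # drop the [CLS]/[SEP] labels

    assert len(labels) == len(offset_mapping)

Without the trim, zip() would silently pair [CLS]'s label with the first character's offsets and shift every entity by one token.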
@@ -104,6 +106,20 @@ class NamedEntityRecognitionPipeline(Pipeline):
                     'start': offsets[0],
                     'end': offsets[1]
                 }
+            if label[0] in 'I':
+                if not chunk:
+                    chunk = {
+                        'type': label[2:],
+                        'start': offsets[0],
+                        'end': offsets[1]
+                    }
+            if label[0] in 'E':
+                if not chunk:
+                    chunk = {
+                        'type': label[2:],
+                        'start': offsets[0],
+                        'end': offsets[1]
+                    }
             if label[0] in 'IES':
                 if chunk:
                     chunk['end'] = offsets[1]
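The two new branches cover BIES label sequences that begin mid-entity (an I- or E- label with no preceding B-), which the old loop silently dropped because chunk was still empty. A standalone sketch of the decoding loop to show the difference; the closing ES branch is reconstructed from context (it is not shown in this hunk) and the inputs below are invented:

    def decode(labels, offset_mapping, text, patch=True):
        """BIES decoding as in the loop above; `patch` toggles the new branches."""
        chunks, chunk = [], {}
        for label, offsets in zip(labels, offset_mapping):
            if label[0] in 'BS':
                if chunk:
                    chunk['span'] = text[chunk['start']:chunk['end']]
                    chunks.append(chunk)
                chunk = {'type': label[2:], 'start': offsets[0], 'end': offsets[1]}
            if patch and label[0] in 'IE' and not chunk:
                # new: open a chunk even when the entity starts with I-/E-
                chunk = {'type': label[2:], 'start': offsets[0], 'end': offsets[1]}
            if label[0] in 'IES':
                if chunk:
                    chunk['end'] = offsets[1]
            if label[0] in 'ES' and chunk:   # close the chunk (reconstructed)
                chunk['span'] = text[chunk['start']:chunk['end']]
                chunks.append(chunk)
                chunk = {}
        return chunks

    text = '王力在杭州'
    labels = ['I-PER', 'E-PER', 'O', 'B-LOC', 'E-LOC']   # entity truncated to I-/E-
    offsets = [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5]]
    decode(labels, offsets, text, patch=False)  # only the LOC chunk survives
    decode(labels, offsets, text, patch=True)   # PER and LOC both recovered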
@@ -118,15 +134,15 @@ class NamedEntityRecognitionPipeline(Pipeline):
             chunk['span'] = text[chunk['start']:chunk['end']]
             chunks.append(chunk)
-        # for cws output
+        # for cws outputs
         if len(chunks) > 0 and chunks[0]['type'] == 'cws':
             spans = [
                 chunk['span'] for chunk in chunks if chunk['span'].strip()
             ]
             seg_result = ' '.join(spans)
-            outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []}
-        # for ner outpus
+            outputs = {OutputKeys.OUTPUT: seg_result}
+        # for ner outputs
         else:
             outputs = {OutputKeys.OUTPUT: chunks}
         return outputs
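Besides the comment typo fix, the cws branch now returns only the joined string; the empty OutputKeys.LABELS entry is dropped. Assuming OutputKeys.OUTPUT resolves to the string 'output', the two result shapes look like this (illustrative values only):

    {'output': '今天 天气 不错'}                                   # cws
    {'output': [{'type': 'PER', 'start': 0, 'end': 2,
                 'span': '王力'}]}                                 # ner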
@@ -95,6 +95,20 @@ class TokenClassificationPipeline(Pipeline):
                     'start': offsets[0],
                     'end': offsets[1]
                 }
+            if label[0] in 'I':
+                if not chunk:
+                    chunk = {
+                        'type': label[2:],
+                        'start': offsets[0],
+                        'end': offsets[1]
+                    }
+            if label[0] in 'E':
+                if not chunk:
+                    chunk = {
+                        'type': label[2:],
+                        'start': offsets[0],
+                        'end': offsets[1]
+                    }
             if label[0] in 'IES':
                 if chunk:
                     chunk['end'] = offsets[1]
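This hunk applies the same I-/E- recovery to TokenClassificationPipeline, so all three patched pipelines decode identically and nothing changes for callers. A hedged usage sketch of the public entry point; the model id is a placeholder, not part of this PR:

    from modelscope.pipelines import pipeline

    ner = pipeline(task='named-entity-recognition',
                   model='damo/some_ner_model')   # hypothetical model id
    print(ner('王力在杭州'))
    # expected shape: {'output': [{'type': ..., 'start': ..., 'end': ..., 'span': ...}, ...]}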
@@ -80,9 +80,12 @@ class WordSegmentationPipeline(Pipeline):
            Dict[str, str]: the prediction results
        """
        text = inputs['text']
-        logits = inputs[OutputKeys.LOGITS]
-        predictions = torch.argmax(logits[0], dim=-1)
-        logits = torch_nested_numpify(torch_nested_detach(logits))
+        if not hasattr(inputs, 'predictions'):
+            logits = inputs[OutputKeys.LOGITS]
+            predictions = torch.argmax(logits[0], dim=-1)
+        else:
+            predictions = inputs[OutputKeys.PREDICTIONS].squeeze(
+                0).cpu().numpy()
         predictions = torch_nested_numpify(torch_nested_detach(predictions))
         offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']]
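The rewritten block lets postprocess accept either raw logits or precomputed predictions. Note the guard uses hasattr, i.e. attribute access on `inputs`, so it assumes a model-output object; if `inputs` is a plain dict, hasattr is always False and the logits path runs. A shape sketch of that path, with assumed dimensions:

    import torch

    logits = torch.randn(1, 5, 4)                   # (batch=1, seq_len=5, num_labels=4)
    predictions = torch.argmax(logits[0], dim=-1)   # shape (5,): one label id per token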
@@ -101,6 +104,20 @@ class WordSegmentationPipeline(Pipeline):
                     'start': offsets[0],
                     'end': offsets[1]
                 }
+            if label[0] in 'I':
+                if not chunk:
+                    chunk = {
+                        'type': label[2:],
+                        'start': offsets[0],
+                        'end': offsets[1]
+                    }
+            if label[0] in 'E':
+                if not chunk:
+                    chunk = {
+                        'type': label[2:],
+                        'start': offsets[0],
+                        'end': offsets[1]
+                    }
             if label[0] in 'IES':
                 if chunk:
                     chunk['end'] = offsets[1]
@@ -123,7 +140,7 @@ class WordSegmentationPipeline(Pipeline):
             seg_result = ' '.join(spans)
             outputs = {OutputKeys.OUTPUT: seg_result}
-        # for ner output
+        # for ner outputs
         else:
             outputs = {OutputKeys.OUTPUT: chunks}
         return outputs
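For reference, the join that produces seg_result (shown in the NER hunk above and shared by this pipeline) filters whitespace-only spans before joining. A minimal sketch with invented chunks:

    chunks = [{'span': '今天'}, {'span': ' '}, {'span': '天气'}, {'span': '不错'}]
    spans = [c['span'] for c in chunks if c['span'].strip()]
    ' '.join(spans)   # -> '今天 天气 不错'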