@@ -80,9 +80,12 @@ class WordSegmentationPipeline(Pipeline):
             Dict[str, str]: the prediction results
         """
         text = inputs['text']
-        logits = inputs[OutputKeys.LOGITS]
-        predictions = torch.argmax(logits[0], dim=-1)
-        logits = torch_nested_numpify(torch_nested_detach(logits))
+        if not hasattr(inputs, 'predictions'):
+            logits = inputs[OutputKeys.LOGITS]
+            predictions = torch.argmax(logits[0], dim=-1)
+        else:
+            predictions = inputs[OutputKeys.PREDICTIONS].squeeze(
+                0).cpu().numpy()
+        predictions = torch_nested_numpify(torch_nested_detach(predictions))
         offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']]
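
This hunk makes postprocess reuse predictions that were already computed upstream, when the model output carries them, instead of always re-deriving them from the logits; it also numpifies the predictions rather than the now-unused logits. Below is a minimal standalone sketch of the same fallback pattern. The pick_predictions helper and tensor shapes are illustrative, a dict membership test stands in for the hasattr check on the output object, and plain .detach().cpu().numpy() stands in for modelscope's torch_nested_* utilities:

import torch

def pick_predictions(outputs: dict):
    # Reuse precomputed predictions when the model already provides them.
    if 'predictions' in outputs:
        # (1, seq_len) -> (seq_len,), detached and moved to host memory.
        return outputs['predictions'].squeeze(0).detach().cpu().numpy()
    # Otherwise fall back to an argmax over the per-token label logits.
    logits = outputs['logits']  # (batch, seq_len, num_labels)
    return torch.argmax(logits[0], dim=-1).detach().cpu().numpy()

# Toy check: one sequence of 4 tokens scored over 3 labels.
logits = torch.randn(1, 4, 3)
print(pick_predictions({'logits': logits}))
print(pick_predictions({'logits': logits,
                        'predictions': torch.tensor([[0, 1, 2, 1]])}))
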
@@ -101,6 +104,20 @@ class WordSegmentationPipeline(Pipeline):
                     'start': offsets[0],
                     'end': offsets[1]
                 }
+            if label[0] in 'I':
+                if not chunk:
+                    chunk = {
+                        'type': label[2:],
+                        'start': offsets[0],
+                        'end': offsets[1]
+                    }
+            if label[0] in 'E':
+                if not chunk:
+                    chunk = {
+                        'type': label[2:],
+                        'start': offsets[0],
+                        'end': offsets[1]
+                    }
             if label[0] in 'IES':
                 if chunk:
                     chunk['end'] = offsets[1]
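
The two new branches guard against label sequences that begin mid-chunk: if an I or E tag arrives while no chunk is open (no preceding B or S), a chunk is opened on the spot instead of the token being silently dropped. Below is a self-contained sketch of the decoding loop this hunk extends. decode_chunks is a hypothetical helper; the B/S opening and E/S flushing are assumed from the surrounding context lines, and the two I/E branches are condensed into one:

def decode_chunks(labels, offset_mapping):
    """Merge per-token BIES labels into {type, start, end} chunks."""
    chunks, chunk = [], {}
    for label, offsets in zip(labels, offset_mapping):
        if label[0] in 'BS':
            # A new chunk begins: flush the previous one first.
            if chunk:
                chunks.append(chunk)
            chunk = {'type': label[2:], 'start': offsets[0], 'end': offsets[1]}
        if label[0] in 'IE' and not chunk:
            # Sequence starts mid-chunk: open a chunk instead of dropping it.
            chunk = {'type': label[2:], 'start': offsets[0], 'end': offsets[1]}
        if label[0] in 'IES' and chunk:
            chunk['end'] = offsets[1]
        if label[0] in 'ES' and chunk:
            # E or S closes the current chunk.
            chunks.append(chunk)
            chunk = {}
    if chunk:
        chunks.append(chunk)
    return chunks

# Toy run: 'ABCDE' segmented into a 3-char word and a 2-char word.
print(decode_chunks(['B-CWS', 'I-CWS', 'E-CWS', 'B-CWS', 'E-CWS'],
                    [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]))
# -> [{'type': 'CWS', 'start': 0, 'end': 3}, {'type': 'CWS', 'start': 3, 'end': 5}]
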
@@ -123,7 +140,7 @@ class WordSegmentationPipeline(Pipeline):
             seg_result = ' '.join(spans)
             outputs = {OutputKeys.OUTPUT: seg_result}
 
-        # for ner output
+        # for ner outputs
         else:
             outputs = {OutputKeys.OUTPUT: chunks}
         return outputs
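
For context, a quick usage sketch of the pipeline this diff touches, assuming modelscope is installed; the model id and the printed result are illustrative:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Any word-segmentation model hosted on ModelScope works here.
word_seg = pipeline(Tasks.word_segmentation,
                    model='damo/nlp_structbert_word-segmentation_chinese-base')
print(word_seg('今天天气不错，适合出去游玩'))
# e.g. {'output': '今天 天气 不错 ， 适合 出去 游玩'}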