
sample_process.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""process txt"""
import re
import json


def process_one_example_p(tokenizer, text, max_seq_len=128):
    """process one testline"""
    textlist = list(text)
    tokens = []
    # Tokenize the input character by character with the WordPiece tokenizer.
    for _, word in enumerate(textlist):
        token = tokenizer.tokenize(word)
        tokens.extend(token)
    # Reserve room for the [CLS] and [SEP] special tokens.
    if len(tokens) >= max_seq_len - 1:
        tokens = tokens[0:(max_seq_len - 2)]
    ntokens = []
    segment_ids = []
    label_ids = []
    ntokens.append("[CLS]")
    segment_ids.append(0)
    for _, token in enumerate(tokens):
        ntokens.append(token)
        segment_ids.append(0)
    ntokens.append("[SEP]")
    segment_ids.append(0)
    input_ids = tokenizer.convert_tokens_to_ids(ntokens)
    input_mask = [1] * len(input_ids)
    # Zero-pad every sequence up to max_seq_len.
    while len(input_ids) < max_seq_len:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        label_ids.append(0)
        ntokens.append("**NULL**")
    assert len(input_ids) == max_seq_len
    assert len(input_mask) == max_seq_len
    assert len(segment_ids) == max_seq_len
    feature = (input_ids, input_mask, segment_ids)
    return feature
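# Worked example (illustrative, not part of the original file): with a
# WordPiece-style tokenizer that keeps each character as a single token and
# max_seq_len=8, a 3-character input such as "ABC" would yield
#   ntokens     = ["[CLS]", "A", "B", "C", "[SEP]", "**NULL**", "**NULL**", "**NULL**"]
#   input_ids   = the 5 real token ids followed by 3 zero-padding ids
#   input_mask  = [1, 1, 1, 1, 1, 0, 0, 0]
#   segment_ids = [0, 0, 0, 0, 0, 0, 0, 0]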


def label_generation(text, probs):
    """generate label"""
    data = [text]
    probs = [probs]
    result = []
    # Map each predicted id back to its BIO label name; label2id.json is
    # assumed to list labels in id order.
    with open("./label2id.json") as f:
        label2id = json.loads(f.read())
    id2label = [k for k, v in label2id.items()]
    for index, prob in enumerate(probs):
        # Skip the [CLS] position and keep one prediction per character.
        for v in prob[1:len(data[index]) + 1]:
            result.append(id2label[int(v)])
    labels = {}
    start = None
    index = 0
    for _, t in zip("".join(data), result):
        # A B- or S- tag starts a new entity; close the previous one first.
        if re.search("^[BS]", t):
            if start is not None:
                label = result[index - 1][2:]
                if labels.get(label):
                    te_ = text[start:index]
                    labels[label][te_] = [[start, index - 1]]
                else:
                    te_ = text[start:index]
                    labels[label] = {te_: [[start, index - 1]]}
            start = index
        # An O tag ends the entity that is currently open, if any.
        if re.search("^O", t):
            if start is not None:
                label = result[index - 1][2:]
                if labels.get(label):
                    te_ = text[start:index]
                    labels[label][te_] = [[start, index - 1]]
                else:
                    te_ = text[start:index]
                    labels[label] = {te_: [[start, index - 1]]}
            start = None
        index += 1
    # Flush an entity that runs to the end of the sentence.
    if start is not None:
        label = result[start][2:]
        if labels.get(label):
            te_ = text[start:index]
            labels[label][te_] = [[start, index - 1]]
        else:
            te_ = text[start:index]
            labels[label] = {te_: [[start, index - 1]]}
    return labels
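

# ---------------------------------------------------------------------------
# Usage sketch (not in the original file): DummyTokenizer is a hypothetical
# stand-in for the BERT WordPiece tokenizer the real pipeline passes in; it
# only needs tokenize() and convert_tokens_to_ids(), which is all that
# process_one_example_p() relies on. label_generation() additionally expects
# a ./label2id.json file (e.g. {"O": 0, "B-PER": 1, "I-PER": 2}) and one
# predicted label id per character, so that call is left commented out.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class DummyTokenizer:
        """Minimal stand-in for a WordPiece tokenizer (illustration only)."""

        def tokenize(self, word):
            return [word]

        def convert_tokens_to_ids(self, tokens):
            return list(range(1, len(tokens) + 1))

    demo_text = "example"
    input_ids, input_mask, segment_ids = process_one_example_p(
        DummyTokenizer(), demo_text, max_seq_len=16)
    print(len(input_ids), len(input_mask), len(segment_ids))  # 16 16 16

    # With ./label2id.json in place, per-character predictions (offset by one
    # for [CLS]) can be decoded into entity spans:
    # demo_probs = [0] * (len(demo_text) + 2)
    # print(label_generation(demo_text, demo_probs))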