You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

small.py 5.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. from utils_ import get_skip_path_trivial, Trie, get_skip_path
  2. from load_data import load_yangjie_rich_pretrain_word_list, load_ontonotes4ner, equip_chinese_ner_with_skip
  3. from pathes import *
  4. from functools import partial
  5. from fastNLP import cache_results
  6. from fastNLP.embeddings.static_embedding import StaticEmbedding
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. from fastNLP.core.metrics import _bmes_tag_to_spans,_bmeso_tag_to_spans
  11. from load_data import load_resume_ner
  12. # embed = StaticEmbedding(None,embedding_dim=2)
  13. # datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
  14. # _refresh=True,index_token=False)
  15. #
  16. # w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path,
  17. # _refresh=False)
  18. #
  19. # datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path,
  20. # _refresh=True)
  21. #
  22. def reverse_style(input_string):
  23. target_position = input_string.index('[')
  24. input_len = len(input_string)
  25. output_string = input_string[target_position:input_len] + input_string[0:target_position]
  26. # print('in:{}.out:{}'.format(input_string, output_string))
  27. return output_string
  28. def get_yangjie_bmeso(label_list):
  29. def get_ner_BMESO_yj(label_list):
  30. # list_len = len(word_list)
  31. # assert(list_len == len(label_list)), "word list size unmatch with label list"
  32. list_len = len(label_list)
  33. begin_label = 'b-'
  34. end_label = 'e-'
  35. single_label = 's-'
  36. whole_tag = ''
  37. index_tag = ''
  38. tag_list = []
  39. stand_matrix = []
  40. for i in range(0, list_len):
  41. # wordlabel = word_list[i]
  42. current_label = label_list[i].lower()
  43. if begin_label in current_label:
  44. if index_tag != '':
  45. tag_list.append(whole_tag + ',' + str(i - 1))
  46. whole_tag = current_label.replace(begin_label, "", 1) + '[' + str(i)
  47. index_tag = current_label.replace(begin_label, "", 1)
  48. elif single_label in current_label:
  49. if index_tag != '':
  50. tag_list.append(whole_tag + ',' + str(i - 1))
  51. whole_tag = current_label.replace(single_label, "", 1) + '[' + str(i)
  52. tag_list.append(whole_tag)
  53. whole_tag = ""
  54. index_tag = ""
  55. elif end_label in current_label:
  56. if index_tag != '':
  57. tag_list.append(whole_tag + ',' + str(i))
  58. whole_tag = ''
  59. index_tag = ''
  60. else:
  61. continue
  62. if (whole_tag != '') & (index_tag != ''):
  63. tag_list.append(whole_tag)
  64. tag_list_len = len(tag_list)
  65. for i in range(0, tag_list_len):
  66. if len(tag_list[i]) > 0:
  67. tag_list[i] = tag_list[i] + ']'
  68. insert_list = reverse_style(tag_list[i])
  69. stand_matrix.append(insert_list)
  70. # print stand_matrix
  71. return stand_matrix
  72. def transform_YJ_to_fastNLP(span):
  73. span = span[1:]
  74. span_split = span.split(']')
  75. # print('span_list:{}'.format(span_split))
  76. span_type = span_split[1]
  77. # print('span_split[0].split(','):{}'.format(span_split[0].split(',')))
  78. if ',' in span_split[0]:
  79. b, e = span_split[0].split(',')
  80. else:
  81. b = span_split[0]
  82. e = b
  83. b = int(b)
  84. e = int(e)
  85. e += 1
  86. return (span_type, (b, e))
  87. yj_form = get_ner_BMESO_yj(label_list)
  88. # print('label_list:{}'.format(label_list))
  89. # print('yj_from:{}'.format(yj_form))
  90. fastNLP_form = list(map(transform_YJ_to_fastNLP,yj_form))
  91. return fastNLP_form
  92. # tag_list = ['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']
  93. # span_list = get_ner_BMES(tag_list)
  94. # print(span_list)
  95. # yangjie_label_list = ['B-NAME', 'E-NAME', 'O', 'B-CONT', 'M-CONT', 'E-CONT', 'B-RACE', 'E-RACE', 'B-TITLE', 'M-TITLE', 'E-TITLE', 'B-EDU', 'M-EDU', 'E-EDU', 'B-ORG', 'M-ORG', 'E-ORG', 'M-NAME', 'B-PRO', 'M-PRO', 'E-PRO', 'S-RACE', 'S-NAME', 'B-LOC', 'M-LOC', 'E-LOC', 'M-RACE', 'S-ORG']
  96. # my_label_list = ['O', 'M-ORG', 'M-TITLE', 'B-TITLE', 'E-TITLE', 'B-ORG', 'E-ORG', 'M-EDU', 'B-NAME', 'E-NAME', 'B-EDU', 'E-EDU', 'M-NAME', 'M-PRO', 'M-CONT', 'B-PRO', 'E-PRO', 'B-CONT', 'E-CONT', 'M-LOC', 'B-RACE', 'E-RACE', 'S-NAME', 'B-LOC', 'E-LOC', 'M-RACE', 'S-RACE', 'S-ORG']
  97. # yangjie_label = set(yangjie_label_list)
  98. # my_label = set(my_label_list)
  99. a = torch.tensor([0,2,0,3])
  100. b = (a==0)
  101. print(b)
  102. print(b.float())
  103. from fastNLP import RandomSampler
  104. # f = open('/remote-home/xnli/weight_debug/lattice_yangjie.pkl','rb')
  105. # weight_dict = torch.load(f)
  106. # print(weight_dict.keys())
  107. # for k,v in weight_dict.items():
  108. # print("{}:{}".format(k,v.size()))