You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

utils.py 4.8 kB

7 years ago
7 years ago
7 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import torch
  2. def seq_lens_to_mask(seq_lens):
  3. batch_size = seq_lens.size(0)
  4. max_len = seq_lens.max()
  5. indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
  6. masks = indexes.lt(seq_lens.unsqueeze(1))
  7. return masks
  8. from itertools import chain
  9. def refine_ys_on_seq_len(ys, seq_lens):
  10. refined_ys = []
  11. for b_idx, length in enumerate(seq_lens):
  12. refined_ys.append(list(ys[b_idx][:length]))
  13. return refined_ys
  14. def flat_nested_list(nested_list):
  15. return list(chain(*nested_list))
  16. def calculate_pre_rec_f1(model, batcher, type='segapp'):
  17. true_ys, pred_ys = decode_iterator(model, batcher)
  18. true_ys = flat_nested_list(true_ys)
  19. pred_ys = flat_nested_list(pred_ys)
  20. cor_num = 0
  21. start = 0
  22. if type=='segapp':
  23. yp_wordnum = pred_ys.count(1)
  24. yt_wordnum = true_ys.count(1)
  25. if true_ys[0]==1 and pred_ys[0]==1:
  26. cor_num += 1
  27. start = 1
  28. for i in range(1, len(true_ys)):
  29. if true_ys[i] == 1:
  30. flag = True
  31. if true_ys[start-1] != pred_ys[start-1]:
  32. flag = False
  33. else:
  34. for j in range(start, i + 1):
  35. if true_ys[j] != pred_ys[j]:
  36. flag = False
  37. break
  38. if flag:
  39. cor_num += 1
  40. start = i + 1
  41. elif type=='bmes':
  42. yp_wordnum = pred_ys.count(2) + pred_ys.count(3)
  43. yt_wordnum = true_ys.count(2) + true_ys.count(3)
  44. for i in range(len(true_ys)):
  45. if true_ys[i] == 2 or true_ys[i] == 3:
  46. flag = True
  47. for j in range(start, i + 1):
  48. if true_ys[j] != pred_ys[j]:
  49. flag = False
  50. break
  51. if flag:
  52. cor_num += 1
  53. start = i + 1
  54. P = cor_num / (float(yp_wordnum) + 1e-6)
  55. R = cor_num / (float(yt_wordnum) + 1e-6)
  56. F = 2 * P * R / (P + R + 1e-6)
  57. # print(cor_num, yt_wordnum, yp_wordnum)
  58. return P, R, F
  59. def decode_iterator(model, batcher):
  60. true_ys = []
  61. pred_ys = []
  62. seq_lens = []
  63. with torch.no_grad():
  64. model.eval()
  65. for batch_x, batch_y in batcher:
  66. pred_dict = model.predict(**batch_x)
  67. seq_len = batch_x['seq_lens'].cpu().numpy()
  68. pred_y = pred_dict['pred_tags']
  69. true_y = batch_y['tags']
  70. pred_y = pred_y.cpu().numpy()
  71. true_y = true_y.cpu().numpy()
  72. true_ys.extend(true_y.tolist())
  73. pred_ys.extend(pred_y.tolist())
  74. seq_lens.extend(list(seq_len))
  75. model.train()
  76. true_ys = refine_ys_on_seq_len(true_ys, seq_lens)
  77. pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens)
  78. return true_ys, pred_ys
  79. from torch import nn
  80. import torch.nn.functional as F
  81. class FocalLoss(nn.Module):
  82. r"""
  83. This criterion is a implemenation of Focal Loss, which is proposed in
  84. Focal Loss for Dense Object Detection.
  85. Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])
  86. The losses are averaged across observations for each minibatch.
  87. Args:
  88. alpha(1D Tensor, Variable) : the scalar factor for this criterion
  89. gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
  90. putting more focus on hard, misclassified examples
  91. size_average(bool): size_average(bool): By default, the losses are averaged over observations for each minibatch.
  92. However, if the field size_average is set to False, the losses are
  93. instead summed for each minibatch.
  94. """
  95. def __init__(self, class_num, gamma=2, size_average=True, reduce=False):
  96. super(FocalLoss, self).__init__()
  97. self.gamma = gamma
  98. self.class_num = class_num
  99. self.size_average = size_average
  100. self.reduce = reduce
  101. def forward(self, inputs, targets):
  102. N = inputs.size(0)
  103. C = inputs.size(1)
  104. P = F.softmax(inputs, dim=-1)
  105. class_mask = inputs.data.new(N, C).fill_(0)
  106. class_mask.requires_grad = True
  107. ids = targets.view(-1, 1)
  108. class_mask = class_mask.scatter(1, ids.data, 1.)
  109. probs = (P * class_mask).sum(1).view(-1, 1)
  110. log_p = probs.log()
  111. batch_loss = - (torch.pow((1 - probs), self.gamma)) * log_p
  112. if self.reduce:
  113. if self.size_average:
  114. loss = batch_loss.mean()
  115. else:
  116. loss = batch_loss.sum()
  117. return loss
  118. return batch_loss