You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

CRF.py 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. '''
  16. CRF script.
  17. '''
  18. import numpy as np
  19. import mindspore.nn as nn
  20. from mindspore.ops import operations as P
  21. from mindspore.common.tensor import Tensor
  22. from mindspore.common.parameter import Parameter
  23. import mindspore.common.dtype as mstype
class CRF(nn.Cell):
    '''
    Conditional Random Field.

    In training mode, computes the CRF negative log-likelihood
    (log partition function minus gold-path score, averaged over the batch).
    In evaluation mode, runs Viterbi decoding and returns backpointers.

    Args:
        tag_to_index (dict): The dict for tag to index mapping with extra
            "<START>" and "<STOP>" signs.
        batch_size (int): Batch size, i.e., the length of the first dimension.
        seq_length (int): Sequence length, i.e., the length of the second dimension.
        is_training (bool): Specifies whether to use training mode.

    Returns:
        Training mode: Tensor, total loss.
        Evaluation mode: Tuple, the index for each step with the highest score;
        Tuple, the index for the last step with the highest score.
    '''

    def __init__(self, tag_to_index: dict, batch_size: int = 1,
                 seq_length: int = 128, is_training: bool = True):
        super(CRF, self).__init__()
        self.target_size = len(tag_to_index)
        self.is_training = is_training
        self.tag_to_index = tag_to_index
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.START_TAG = "<START>"
        self.STOP_TAG = "<STOP>"
        # Assumes START/STOP occupy the last two indices of the tag map
        # (target_size-2 and target_size-1) -- NOTE(review): confirm against
        # the caller that builds tag_to_index; _realpath_score/_decoder rely
        # on this layout when they slice row (target_size-1) below.
        self.START_VALUE = Tensor(self.target_size-2, dtype=mstype.int32)
        self.STOP_VALUE = Tensor(self.target_size-1, dtype=mstype.int32)
        # Randomly initialized learnable transition matrix. The START row and
        # the STOP column are pinned to a large negative score so those
        # transitions are effectively impossible after softmax/log-sum-exp.
        transitions = np.random.normal(size=(self.target_size, self.target_size)).astype(np.float32)
        transitions[tag_to_index[self.START_TAG], :] = -10000
        transitions[:, tag_to_index[self.STOP_TAG]] = -10000
        self.transitions = Parameter(Tensor(transitions), name="transition_matrix")
        # Cached primitive ops reused by the score / decode methods.
        self.cat = P.Concat(axis=-1)
        self.argmax = P.ArgMaxWithValue(axis=-1)
        self.log = P.Log()
        self.exp = P.Exp()
        self.sum = P.ReduceSum()
        self.tile = P.Tile()
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.reshape = P.Reshape()
        self.expand = P.ExpandDims()
        self.mean = P.ReduceMean()
        # Initial forward scores: -10000 everywhere except 0 at the START
        # tag, so every path is forced to begin at START.
        init_alphas = np.ones(shape=(self.batch_size, self.target_size)) * -10000.0
        init_alphas[:, self.tag_to_index[self.START_TAG]] = 0.
        self.init_alphas = Tensor(init_alphas, dtype=mstype.float32)
        self.cast = P.Cast()
        self.reduce_max = P.ReduceMax(keep_dims=True)
        # One-hot on/off values used by _realpath_score.
        self.on_value = Tensor(1.0, dtype=mstype.float32)
        self.off_value = Tensor(0.0, dtype=mstype.float32)
        self.onehot = P.OneHot()

    def log_sum_exp(self, logits):
        '''
        Compute the log_sum_exp score for normalization factor.
        '''
        # Subtract the per-row max before exponentiating for numerical
        # stability, then add it back after the log.
        max_score = self.reduce_max(logits, -1)
        score = self.log(self.reduce_sum(self.exp(logits - max_score), -1))
        score = max_score + score
        return score

    def _realpath_score(self, features, label):
        '''
        Compute the emission and transition score for the real path.
        '''
        # Defensive copy of the label tensor (multiply by 1 yields a new
        # tensor rather than aliasing the input).
        label = label * 1
        # Prepend the START index to every sequence in the batch so that
        # consecutive pairs in `labels` include the START->first transition.
        concat_A = self.tile(self.reshape(self.START_VALUE, (1,)), (self.batch_size,))
        concat_A = self.reshape(concat_A, (self.batch_size, 1))
        labels = self.cat((concat_A, label))
        # Emission score: select each step's feature at its gold tag.
        onehot_label = self.onehot(label, self.target_size, self.on_value, self.off_value)
        emits = features * onehot_label
        # Transition score: the outer product of one-hot vectors of adjacent
        # tags yields a one-hot indicator over (from_tag, to_tag) pairs.
        labels = self.onehot(labels, self.target_size, self.on_value, self.off_value)
        label1 = labels[:, 1:, :]
        label2 = labels[:, :self.seq_length, :]
        label1 = self.expand(label1, 3)
        label2 = self.expand(label2, 2)
        label_trans = label1 * label2
        transitions = self.expand(self.expand(self.transitions, 0), 0)
        trans = transitions * label_trans
        score = self.sum(emits, (1, 2)) + self.sum(trans, (1, 2, 3))
        # Add the transition from the last tag into STOP; row target_size-1
        # is the STOP row under the assumed tag layout (see __init__ note).
        stop_value_index = labels[:, (self.seq_length-1):self.seq_length, :]
        stop_value = self.transitions[(self.target_size-1):self.target_size, :]
        stop_score = stop_value * self.reshape(stop_value_index, (self.batch_size, self.target_size))
        score = score + self.sum(stop_score, 1)
        score = self.reshape(score, (self.batch_size, -1))
        return score

    def _normalization_factor(self, features):
        '''
        Compute the total score for all the paths.
        '''
        # Forward algorithm: features is expected to be
        # (batch_size, seq_length, target_size), per the reshape below.
        forward_var = self.init_alphas
        forward_var = self.expand(forward_var, 1)
        for idx in range(self.seq_length):
            feat = features[:, idx:(idx+1), :]
            emit_score = self.reshape(feat, (self.batch_size, self.target_size, 1))
            # Broadcast emission + transition + previous forward scores, then
            # log-sum-exp over the previous tag axis.
            next_tag_var = emit_score + self.transitions + forward_var
            forward_var = self.log_sum_exp(next_tag_var)
            forward_var = self.reshape(forward_var, (self.batch_size, 1, self.target_size))
        # Final transition into STOP (last row of the transition matrix).
        terminal_var = forward_var + self.reshape(self.transitions[(self.target_size-1):self.target_size, :], (1, -1))
        alpha = self.log_sum_exp(terminal_var)
        alpha = self.reshape(alpha, (self.batch_size, -1))
        return alpha

    def _decoder(self, features):
        '''
        Viterbi decode for evaluation.
        '''
        # Collect, for each step, the best previous tag per current tag;
        # postprocess() walks these backpointers to recover the path.
        backpointers = ()
        forward_var = self.init_alphas
        for idx in range(self.seq_length):
            feat = features[:, idx:(idx+1), :]
            feat = self.reshape(feat, (self.batch_size, self.target_size))
            bptrs_t = ()
            # Max over previous tags instead of log-sum-exp (Viterbi).
            next_tag_var = self.expand(forward_var, 1) + self.transitions
            best_tag_id, best_tag_value = self.argmax(next_tag_var)
            bptrs_t += (best_tag_id,)
            forward_var = best_tag_value + feat
            backpointers += (bptrs_t,)
        # Final transition into STOP; best last tag per batch element.
        terminal_var = forward_var + self.reshape(self.transitions[(self.target_size-1):self.target_size, :], (1, -1))
        best_tag_id, _ = self.argmax(terminal_var)
        return backpointers, best_tag_id

    def construct(self, features, label):
        '''
        Training: mean over the batch of (log-partition - gold-path score).
        Evaluation: (backpointers, best final tag ids) for postprocess().
        '''
        if self.is_training:
            forward_score = self._normalization_factor(features)
            gold_score = self._realpath_score(features, label)
            return_value = self.mean(forward_score - gold_score)
        else:
            path_list, tag = self._decoder(features)
            return_value = path_list, tag
        return return_value
  146. def postprocess(backpointers, best_tag_id):
  147. '''
  148. Do postprocess
  149. '''
  150. best_tag_id = best_tag_id.asnumpy()
  151. batch_size = len(best_tag_id)
  152. best_path = []
  153. for i in range(batch_size):
  154. best_path.append([])
  155. best_local_id = best_tag_id[i]
  156. best_path[-1].append(best_local_id)
  157. for bptrs_t in reversed(backpointers):
  158. bptrs_t = bptrs_t[0].asnumpy()
  159. local_idx = bptrs_t[i]
  160. best_local_id = local_idx[best_local_id]
  161. best_path[-1].append(best_local_id)
  162. # Pop off the start tag (we dont want to return that to the caller)
  163. best_path[-1].pop()
  164. best_path[-1].reverse()
  165. return best_path