You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

textrcnn.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """model textrcnn"""
  16. import numpy as np
  17. import mindspore.nn as nn
  18. from mindspore.ops import operations as P
  19. from mindspore.ops import functional as F
  20. from mindspore.common.parameter import Parameter
  21. from mindspore import Tensor
  22. from mindspore.common import dtype as mstype
class textrcnn(nn.Cell):
    """TextRCNN model for binary text classification.

    Runs a bidirectional recurrent encoder (LSTM, GRU, or vanilla RNN,
    selected by ``cell``) over word embeddings, concatenates each token's
    left context, embedding, and right context, projects through a dense
    layer with tanh, max-pools over the sequence, and classifies into
    ``num_classes`` (2) classes.  All recurrent arithmetic is done in
    float16.

    Args:
        weight: pretrained embedding table passed to ``nn.Embedding``
            (frozen — ``requires_grad`` is set to False below).
        vocab_size (int): vocabulary size for the embedding table.
        cell (str): recurrent cell type, one of "lstm", "gru", "vanilla".
        batch_size (int): fixed batch size (needed to pre-build the
            initial hidden/cell state tensors).
    """

    def __init__(self, weight, vocab_size, cell, batch_size):
        super(textrcnn, self).__init__()
        self.num_hiddens = 512
        self.embed_size = 300
        self.num_classes = 2
        self.batch_size = batch_size
        # Uniform init bound 1/sqrt(num_hiddens) for the LSTM weight/bias.
        k = (1 / self.num_hiddens) ** 0.5
        self.embedding = nn.Embedding(vocab_size, self.embed_size, embedding_table=weight)
        # Embeddings are frozen: the pretrained table is not fine-tuned.
        self.embedding.embedding_table.requires_grad = False
        self.cell = cell
        self.cast = P.Cast()
        # Initial hidden/cell states; re-created with a leading axis of 1
        # below when cell == "lstm" (DynamicRNN expects rank-3 states).
        self.h1 = Tensor(np.zeros(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))
        self.c1 = Tensor(np.zeros(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))

        if cell == "lstm":
            self.lstm = P.DynamicRNN(forget_bias=0.0)
            # Fused gate weights: (embed_size + num_hiddens, 4 * num_hiddens),
            # separate forward (_fw) and backward (_bw) directions.
            self.w1_fw = Parameter(
                np.random.uniform(-k, k, (self.embed_size + self.num_hiddens, 4 * self.num_hiddens)).astype(
                    np.float16), name="w1_fw")
            self.b1_fw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float16),
                                   name="b1_fw")
            self.w1_bw = Parameter(
                np.random.uniform(-k, k, (self.embed_size + self.num_hiddens, 4 * self.num_hiddens)).astype(
                    np.float16), name="w1_bw")
            self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float16),
                                   name="b1_bw")
            self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float16))
            self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float16))

        if cell == "vanilla":
            # Plain RNN: h_t = tanh(W h_{t-1} + U x_t), per direction.
            self.rnnW_fw = nn.Dense(self.num_hiddens, self.num_hiddens)
            self.rnnU_fw = nn.Dense(self.embed_size, self.num_hiddens)
            self.rnnW_bw = nn.Dense(self.num_hiddens, self.num_hiddens)
            self.rnnU_bw = nn.Dense(self.embed_size, self.num_hiddens)

        if cell == "gru":
            # GRU gates operate on concat([h, x]): reset (Wr), update (Wz),
            # candidate (Wh), per direction.
            self.rnnWr_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWz_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWh_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWr_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWz_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWh_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            # Constant ones tensor used for (1 - z) in the GRU update.
            self.ones = Tensor(np.ones(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))
            self.rnnWr_fw.to_float(mstype.float16)
            self.rnnWz_fw.to_float(mstype.float16)
            self.rnnWh_fw.to_float(mstype.float16)
            self.rnnWr_bw.to_float(mstype.float16)
            self.rnnWz_bw.to_float(mstype.float16)
            self.rnnWh_bw.to_float(mstype.float16)

        self.transpose = P.Transpose()
        self.reduce_max = P.ReduceMax()
        self.expand_dims = P.ExpandDims()
        self.concat = P.Concat()
        self.reshape = P.Reshape()
        # Zero padding used to shift the forward/backward context by one
        # time step when building left/right contexts in construct().
        self.left_pad_tensor = Tensor(np.zeros((1, self.batch_size, self.num_hiddens)).astype(np.float16))
        self.right_pad_tensor = Tensor(np.zeros((1, self.batch_size, self.num_hiddens)).astype(np.float16))
        self.output_dense = nn.Dense(self.num_hiddens * 1, 2)
        self.concat0 = P.Concat(0)
        self.concat2 = P.Concat(2)
        self.concat1 = P.Concat(1)
        # Projects [left_ctx; embedding; right_ctx] down to num_hiddens.
        self.text_rep_dense = nn.Dense(2 * self.num_hiddens + self.embed_size, self.num_hiddens)
        # Final classifier head: num_hiddens -> num_classes.
        self.mydense = nn.Dense(self.num_hiddens, 2)
        self.drop_out = nn.Dropout(keep_prob=0.7)
        self.tanh = P.Tanh()
        self.sigmoid = P.Sigmoid()
        self.slice = P.Slice()
        self.text_rep_dense.to_float(mstype.float16)
        self.mydense.to_float(mstype.float16)
        self.output_dense.to_float(mstype.float16)

    def construct(self, x):
        """Forward pass.

        Args:
            x: token-id tensor of shape (batch_size, seq_len).

        Returns:
            Logits tensor of shape (batch_size, num_classes), float16.
        """
        # x: bs, sl
        # Pre-assign so both names are defined on every branch (graph mode).
        output_fw = x
        output_bw = x

        if self.cell == "vanilla":
            x = self.embedding(x)  # bs, sl, emb_size
            x = self.cast(x, mstype.float16)
            x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
            x = self.drop_out(x)  # sl,bs, emb_size

            # Forward direction: unroll t = 0 .. sl-1.
            h1_fw = self.cast(self.h1, mstype.float16)  # bs, num_hidden
            h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[0, :, :]))  # bs, num_hidden
            output_fw = self.expand_dims(h1_fw, 0)  # 1, bs, num_hidden

            for i in range(1, F.shape(x)[0]):
                h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[i, :, :]))  # 1, bs, num_hidden
                h1_after_expand_fw = self.expand_dims(h1_fw, 0)
                output_fw = self.concat((output_fw, h1_after_expand_fw))  # 2/3/4.., bs, num_hidden
            output_fw = self.cast(output_fw, mstype.float16)  # sl, bs, num_hidden

            # Backward direction: unroll t = sl-1 .. 0, prepending so the
            # result stays in time order.
            h1_bw = self.cast(self.h1, mstype.float16)  # bs, num_hidden
            h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[F.shape(x)[0] - 1, :, :]))  # bs, num_hidden
            output_bw = self.expand_dims(h1_bw, 0)  # 1, bs, num_hidden

            for i in range(F.shape(x)[0] - 2, -1, -1):
                h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[i, :, :]))  # 1, bs, num_hidden
                h1_after_expand_bw = self.expand_dims(h1_bw, 0)
                output_bw = self.concat((h1_after_expand_bw, output_bw))  # 2/3/4.., bs, num_hidden
            output_bw = self.cast(output_bw, mstype.float16)  # sl, bs, num_hidden

        if self.cell == "gru":
            x = self.embedding(x)  # bs, sl, emb_size
            x = self.cast(x, mstype.float16)
            x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
            x = self.drop_out(x)  # sl,bs, emb_size

            # Forward direction GRU step for t = 0, then unroll the rest.
            h_fw = self.cast(self.h1, mstype.float16)

            h_x_fw = self.concat1((h_fw, x[0, :, :]))
            r_fw = self.sigmoid(self.rnnWr_fw(h_x_fw))
            z_fw = self.sigmoid(self.rnnWz_fw(h_x_fw))
            h_tilde_fw = self.tanh(self.rnnWh_fw(self.concat1((r_fw * h_fw, x[0, :, :]))))
            h_fw = (self.ones - z_fw) * h_fw + z_fw * h_tilde_fw
            output_fw = self.expand_dims(h_fw, 0)

            for i in range(1, F.shape(x)[0]):
                h_x_fw = self.concat1((h_fw, x[i, :, :]))
                r_fw = self.sigmoid(self.rnnWr_fw(h_x_fw))
                z_fw = self.sigmoid(self.rnnWz_fw(h_x_fw))
                h_tilde_fw = self.tanh(self.rnnWh_fw(self.concat1((r_fw * h_fw, x[i, :, :]))))
                h_fw = (self.ones - z_fw) * h_fw + z_fw * h_tilde_fw
                h_after_expand_fw = self.expand_dims(h_fw, 0)
                output_fw = self.concat((output_fw, h_after_expand_fw))
            output_fw = self.cast(output_fw, mstype.float16)

            # Backward direction GRU, prepending to keep time order.
            h_bw = self.cast(self.h1, mstype.float16)  # bs, num_hidden

            h_x_bw = self.concat1((h_bw, x[F.shape(x)[0] - 1, :, :]))
            r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw))
            z_bw = self.sigmoid(self.rnnWz_bw(h_x_bw))
            h_tilde_bw = self.tanh(self.rnnWh_bw(self.concat1((r_bw * h_bw, x[F.shape(x)[0] - 1, :, :]))))
            h_bw = (self.ones - z_bw) * h_bw + z_bw * h_tilde_bw
            output_bw = self.expand_dims(h_bw, 0)

            for i in range(F.shape(x)[0] - 2, -1, -1):
                h_x_bw = self.concat1((h_bw, x[i, :, :]))
                r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw))
                z_bw = self.sigmoid(self.rnnWz_bw(h_x_bw))
                h_tilde_bw = self.tanh(self.rnnWh_bw(self.concat1((r_bw * h_bw, x[i, :, :]))))
                h_bw = (self.ones - z_bw) * h_bw + z_bw * h_tilde_bw
                h_after_expand_bw = self.expand_dims(h_bw, 0)
                output_bw = self.concat((h_after_expand_bw, output_bw))
            output_bw = self.cast(output_bw, mstype.float16)

        if self.cell == 'lstm':
            x = self.embedding(x)  # bs, sl, emb_size
            x = self.cast(x, mstype.float16)
            x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
            x = self.drop_out(x)  # sl,bs, emb_size

            # DynamicRNN runs the whole sequence at once; the second output
            # is the per-step hidden states (sl, bs, num_hidden).
            h1_fw_init = self.h1  # bs, num_hidden
            c1_fw_init = self.c1  # bs, num_hidden

            _, output_fw, _, _, _, _, _, _ = self.lstm(x, self.w1_fw, self.b1_fw, None, h1_fw_init, c1_fw_init)
            output_fw = self.cast(output_fw, mstype.float16)  # sl, bs, num_hidden

            h1_bw_init = self.h1  # bs, num_hidden
            c1_bw_init = self.c1  # bs, num_hidden
            # NOTE(review): the backward pass feeds x in forward order; the
            # "backward" direction differs only in its weights — confirm this
            # matches the intended bidirectional design.
            _, output_bw, _, _, _, _, _, _ = self.lstm(x, self.w1_bw, self.b1_bw, None, h1_bw_init, c1_bw_init)
            output_bw = self.cast(output_bw, mstype.float16)  # sl, bs, hidden

        # Left context: forward states shifted right by one step (zero pad at
        # t=0). Right context: backward states shifted left by one step.
        c_left = self.concat0((self.left_pad_tensor, output_fw[:F.shape(x)[0] - 1]))  # sl, bs, num_hidden
        c_right = self.concat0((output_bw[1:], self.right_pad_tensor))  # sl, bs, num_hidden
        output = self.concat2((c_left, self.cast(x, mstype.float16), c_right))  # sl, bs, 2*num_hidden+emb_size

        output = self.cast(output, mstype.float16)
        output_flat = self.reshape(output, (F.shape(x)[0] * self.batch_size, 2 * self.num_hiddens + self.embed_size))
        output_dense = self.text_rep_dense(output_flat)  # sl*bs, num_hidden
        output_dense = self.tanh(output_dense)  # sl*bs, num_hidden
        output = self.reshape(output_dense, (F.shape(x)[0], self.batch_size, self.num_hiddens))  # sl, bs, num_hidden
        # Max-pool over the time axis, then classify.
        output = self.reduce_max(output, 0)  # bs, num_hidden
        outputs = self.cast(self.mydense(output), mstype.float16)  # bs, num_classes

        return outputs