
textcnn.py 5.1 kB

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
  15. """TextCNN"""
  16. import mindspore.nn as nn
  17. import mindspore.ops.operations as P
  18. from mindspore import Tensor
  19. from mindspore.nn.cell import Cell
  20. import mindspore.ops.functional as F
  21. import mindspore
class SoftmaxCrossEntropyExpand(Cell):
    r"""
    Computes softmax cross entropy between logits and labels, implemented by
    the expanded formula rather than a single fused operator. This is a
    wrapper around several primitive operations.

    .. math::
        \ell(x_i, t_i) = -\log\left(\frac{\exp(x_{t_i})}{\sum_j \exp(x_j)}\right),

    where :math:`x_i` is a 1-D score Tensor and :math:`t_i` is the target class.

    Note:
        When the argument ``sparse`` is set to True, labels are class indices
        in the range :math:`0` to :math:`C - 1` instead of one-hot vectors.

    Args:
        sparse (bool): Specifies whether labels use the sparse format. Default: False.

    Inputs:
        - **input_data** (Tensor) - Tensor of shape :math:`(x_1, x_2, ..., x_R)`.
        - **label** (Tensor) - Tensor of shape :math:`(y_1, y_2, ..., y_S)`.

    Outputs:
        Tensor, a scalar tensor containing the mean loss.

    Examples:
        >>> loss = SoftmaxCrossEntropyExpand(sparse=True)
        >>> input_data = Tensor(np.ones([64, 512]), dtype=mindspore.float32)
        >>> label = Tensor(np.ones([64]), dtype=mindspore.int32)
        >>> loss(input_data, label)
    """
    def __init__(self, sparse=False):
        super(SoftmaxCrossEntropyExpand, self).__init__()
        self.exp = P.Exp()
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mindspore.float32)
        self.off_value = Tensor(0.0, mindspore.float32)
        self.div = P.Div()
        self.log = P.Log()
        self.sum_cross_entropy = P.ReduceSum(keep_dims=False)
        self.mul = P.Mul()
        self.mul2 = P.Mul()
        self.cast = P.Cast()
        self.reduce_mean = P.ReduceMean(keep_dims=False)
        self.sparse = sparse
        self.reduce_max = P.ReduceMax(keep_dims=True)
        self.sub = P.Sub()
    def construct(self, logit, label):
        """Compute the mean softmax cross-entropy loss over the batch."""
        # Subtract the per-row maximum before exponentiating for numerical stability.
        logit_max = self.reduce_max(logit, -1)
        exp = self.exp(self.sub(logit, logit_max))
        exp_sum = self.reduce_sum(exp, -1)
        softmax_result = self.div(exp, exp_sum)
        # With sparse labels, expand class indices into one-hot vectors first.
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        softmax_result_log = self.log(softmax_result)
        loss = self.sum_cross_entropy(self.mul(softmax_result_log, label), -1)
        loss = self.mul2(F.scalar_to_array(-1.0), loss)
        loss = self.reduce_mean(loss, -1)
        return loss


def make_conv_layer(kernel_size):
    """Build a single-input-channel convolution with 96 output channels."""
    return nn.Conv2d(in_channels=1, out_channels=96, kernel_size=kernel_size,
                     padding=1, pad_mode="pad", has_bias=True)


class TextCNN(nn.Cell):
    """
    TextCNN architecture: three parallel convolution branches with kernel
    heights 3, 4, and 5 over the word embeddings, each followed by ReLU and
    max pooling; the pooled features are concatenated, passed through
    dropout, and classified by a fully connected layer.
    """
    def __init__(self, vocab_len, word_len, num_classes, vec_length, embedding_table='uniform'):
        super(TextCNN, self).__init__()
        self.vec_length = vec_length
        self.word_len = word_len
        self.num_classes = num_classes
        self.unsqueeze = P.ExpandDims()
        self.embedding = nn.Embedding(vocab_len, self.vec_length, embedding_table=embedding_table)
        self.slice = P.Slice()
        self.layer1 = self.make_layer(kernel_height=3)
        self.layer2 = self.make_layer(kernel_height=4)
        self.layer3 = self.make_layer(kernel_height=5)
        self.concat = P.Concat(1)
        # Three branches with 96 channels each feed the classifier.
        self.fc = nn.Dense(96 * 3, self.num_classes)
        self.drop = nn.Dropout(keep_prob=0.5)
        self.print = P.Print()
        self.reducemax = P.ReduceMax(keep_dims=False)

    def make_layer(self, kernel_height):
        return nn.SequentialCell(
            [
                make_conv_layer((kernel_height, self.vec_length)), nn.ReLU(),
                nn.MaxPool2d(kernel_size=(self.word_len - kernel_height + 1, 1)),
            ]
        )
    def construct(self, x):
        """Forward pass: embed, convolve, pool, concatenate, classify."""
        # (batch, word_len) -> (batch, 1, word_len) -> (batch, 1, word_len, vec_length)
        x = self.unsqueeze(x, 1)
        x = self.embedding(x)
        x1 = self.layer1(x)
        x2 = self.layer2(x)
        x3 = self.layer3(x)
        # Collapse the remaining spatial dimensions of each branch to (batch, 96).
        x1 = self.reducemax(x1, (2, 3))
        x2 = self.reducemax(x2, (2, 3))
        x3 = self.reducemax(x3, (2, 3))
        x = self.concat((x1, x2, x3))
        x = self.drop(x)
        x = self.fc(x)
        return x
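
For orientation, here is a minimal usage sketch of the two cells defined above. It is not part of textcnn.py; every hyperparameter (vocab_len, word_len, num_classes, vec_length, batch size) is an assumed value for illustration, not taken from this repository's training configuration.

# Usage sketch (illustrative only; all hyperparameters below are assumptions).
import numpy as np
import mindspore
from mindspore import Tensor

net = TextCNN(vocab_len=20000, word_len=51, num_classes=2, vec_length=40)
loss_fn = SoftmaxCrossEntropyExpand(sparse=True)

# A batch of 8 token-id sequences, each padded or truncated to word_len.
ids = Tensor(np.random.randint(0, 20000, (8, 51)), mindspore.int32)
logits = net(ids)                         # shape: (8, num_classes)
labels = Tensor(np.zeros(8), mindspore.int32)
loss = loss_fn(logits, labels)            # scalar mean cross-entropy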