You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

lstm.py 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """LSTM."""
  16. import numpy as np
  17. from mindspore import Tensor, nn, context
  18. from mindspore.ops import operations as P
  19. # Initialize short-term memory (h) and long-term memory (c) to 0
  20. def lstm_default_state(batch_size, hidden_size, num_layers, bidirectional):
  21. """init default input."""
  22. num_directions = 1
  23. if bidirectional:
  24. num_directions = 2
  25. if context.get_context("device_target") == "CPU":
  26. h_list = []
  27. c_list = []
  28. i = 0
  29. while i < num_layers:
  30. hi = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32))
  31. h_list.append(hi)
  32. ci = Tensor(np.zeros((num_directions, batch_size, hidden_size)).astype(np.float32))
  33. c_list.append(ci)
  34. i = i + 1
  35. h = tuple(h_list)
  36. c = tuple(c_list)
  37. return h, c
  38. h = Tensor(
  39. np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
  40. c = Tensor(
  41. np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
  42. return h, c
class SentimentNet(nn.Cell):
    """Sentiment network structure.

    Pipeline: Embedding (frozen pretrained table) -> stacked LSTM encoder ->
    concat of first and last time-step outputs -> Dense classifier.
    """

    def __init__(self,
                 vocab_size,
                 embed_size,
                 num_hiddens,
                 num_layers,
                 bidirectional,
                 num_classes,
                 weight,
                 batch_size):
        super(SentimentNet, self).__init__()
        # Map words to vectors using a pretrained table; frozen (not trained).
        self.embedding = nn.Embedding(vocab_size,
                                      embed_size,
                                      embedding_table=weight)
        self.embedding.embedding_table.requires_grad = False
        # Transpose (batch, seq, embed) -> (seq, batch, embed) for the LSTM.
        self.trans = P.Transpose()
        self.perm = (1, 0, 2)
        self.encoder = nn.LSTM(input_size=embed_size,
                               hidden_size=num_hiddens,
                               num_layers=num_layers,
                               has_bias=True,
                               bidirectional=bidirectional,
                               dropout=0.0)
        # Zero initial states built once for a fixed batch_size; inputs with a
        # different batch dimension would not match — TODO confirm with caller.
        self.h, self.c = lstm_default_state(batch_size, num_hiddens, num_layers, bidirectional)
        self.concat = P.Concat(1)
        # Decoder input width: per-step LSTM output is num_hiddens (x2 if
        # bidirectional), and construct concatenates two time steps, doubling
        # it again — hence 4x for bidirectional, 2x otherwise.
        if bidirectional:
            self.decoder = nn.Dense(num_hiddens * 4, num_classes)
        else:
            self.decoder = nn.Dense(num_hiddens * 2, num_classes)

    def construct(self, inputs):
        """Run the forward pass; returns class logits for each batch element."""
        # input:(64,500,300)
        embeddings = self.embedding(inputs)
        embeddings = self.trans(embeddings, self.perm)
        output, _ = self.encoder(embeddings, (self.h, self.c))
        # Concatenate the first and last time-step outputs along the feature
        # axis, e.g. states[i] size(64,200) -> encoding.size(64,400).
        encoding = self.concat((output[0], output[-1]))
        outputs = self.decoder(encoding)
        return outputs