# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""TextRCNN model definition."""
import numpy as np

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore import Tensor
from mindspore.common import dtype as mstype


class textrcnn(nn.Cell):
    """TextRCNN: a bidirectional recurrent layer (vanilla RNN, GRU or LSTM)
    whose left/right context vectors are concatenated with each word embedding,
    projected, and max-pooled over time for classification.

    Args:
        weight (Tensor): pre-trained embedding table of shape (vocab_size, embed_size).
        vocab_size (int): vocabulary size.
        cell (str): recurrent cell type, one of "vanilla", "gru" or "lstm".
        batch_size (int): batch size.
    """

    def __init__(self, weight, vocab_size, cell, batch_size):
        super(textrcnn, self).__init__()
        self.num_hiddens = 512
        self.embed_size = 300
        self.num_classes = 2
        self.batch_size = batch_size
        k = (1 / self.num_hiddens) ** 0.5
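        # k = 1 / sqrt(num_hiddens): uniform initialization bound for the
        # recurrent weights and biases below.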

        self.embedding = nn.Embedding(vocab_size, self.embed_size, embedding_table=weight)
        self.embedding.embedding_table.requires_grad = False
        self.cell = cell

        self.cast = P.Cast()

        self.h1 = Tensor(np.zeros(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))
        self.c1 = Tensor(np.zeros(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))

        if cell == "lstm":
            self.lstm = P.DynamicRNN(forget_bias=0.0)
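            # DynamicRNN packs the four LSTM gates into a single weight of
            # shape (embed_size + num_hiddens, 4 * num_hiddens) and a single
            # bias of shape (4 * num_hiddens,).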
            self.w1_fw = Parameter(
                np.random.uniform(-k, k, (self.embed_size + self.num_hiddens, 4 * self.num_hiddens)).astype(
                    np.float16), name="w1_fw")
            self.b1_fw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float16),
                                   name="b1_fw")
            self.w1_bw = Parameter(
                np.random.uniform(-k, k, (self.embed_size + self.num_hiddens, 4 * self.num_hiddens)).astype(
                    np.float16), name="w1_bw")
            self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.num_hiddens)).astype(np.float16),
                                   name="b1_bw")
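            # DynamicRNN takes initial hidden/cell states with an explicit
            # leading num_layers dimension, hence the (1, batch, hidden) shape.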
            self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float16))
            self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.num_hiddens)).astype(np.float16))

        if cell == "vanilla":
            self.rnnW_fw = nn.Dense(self.num_hiddens, self.num_hiddens)
            self.rnnU_fw = nn.Dense(self.embed_size, self.num_hiddens)
            self.rnnW_bw = nn.Dense(self.num_hiddens, self.num_hiddens)
            self.rnnU_bw = nn.Dense(self.embed_size, self.num_hiddens)
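            # Vanilla RNN step: h_t = tanh(W * h_{t-1} + U * x_t), applied per
            # direction in construct().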

        if cell == "gru":
            self.rnnWr_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWz_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWh_fw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWr_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWz_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
            self.rnnWh_bw = nn.Dense(self.num_hiddens + self.embed_size, self.num_hiddens)
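            # GRU step on [h_{t-1}; x_t]: reset gate r = sigmoid(Wr .), update
            # gate z = sigmoid(Wz .), candidate h~ = tanh(Wh [r * h; x]), and
            # h_t = (1 - z) * h_{t-1} + z * h~.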
            self.ones = Tensor(np.ones(shape=(self.batch_size, self.num_hiddens)).astype(np.float16))
            self.rnnWr_fw.to_float(mstype.float16)
            self.rnnWz_fw.to_float(mstype.float16)
            self.rnnWh_fw.to_float(mstype.float16)
            self.rnnWr_bw.to_float(mstype.float16)
            self.rnnWz_bw.to_float(mstype.float16)
            self.rnnWh_bw.to_float(mstype.float16)

        self.transpose = P.Transpose()
        self.reduce_max = P.ReduceMax()
        self.expand_dims = P.ExpandDims()
        self.concat = P.Concat()

        self.reshape = P.Reshape()
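        # Zero pads supply the missing left context of the first token and the
        # missing right context of the last token.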
        self.left_pad_tensor = Tensor(np.zeros((1, self.batch_size, self.num_hiddens)).astype(np.float16))
        self.right_pad_tensor = Tensor(np.zeros((1, self.batch_size, self.num_hiddens)).astype(np.float16))
        self.output_dense = nn.Dense(self.num_hiddens, self.num_classes)
        self.concat0 = P.Concat(0)
        self.concat2 = P.Concat(2)
        self.concat1 = P.Concat(1)
        self.text_rep_dense = nn.Dense(2 * self.num_hiddens + self.embed_size, self.num_hiddens)
        self.mydense = nn.Dense(self.num_hiddens, self.num_classes)
        self.drop_out = nn.Dropout(keep_prob=0.7)
        self.tanh = P.Tanh()
        self.sigmoid = P.Sigmoid()
        self.slice = P.Slice()
        self.text_rep_dense.to_float(mstype.float16)
        self.mydense.to_float(mstype.float16)
        self.output_dense.to_float(mstype.float16)

    def construct(self, x):
        """Forward pass: token ids (batch, seq_len) -> logits (batch, num_classes)."""
        # x: bs, sl
        output_fw = x
        output_bw = x

        if self.cell == "vanilla":
            x = self.embedding(x)  # bs, sl, emb_size
            x = self.cast(x, mstype.float16)
            x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
            x = self.drop_out(x)  # sl, bs, emb_size
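            # Unrolled bidirectional scan: the forward pass walks t = 0..sl-1,
            # the backward pass walks t = sl-1..0, each stacking its hidden
            # states back into a (sl, bs, num_hiddens) sequence.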

            h1_fw = self.cast(self.h1, mstype.float16)  # bs, num_hidden
            h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[0, :, :]))  # bs, num_hidden
            output_fw = self.expand_dims(h1_fw, 0)  # 1, bs, num_hidden

            for i in range(1, F.shape(x)[0]):
                h1_fw = self.tanh(self.rnnW_fw(h1_fw) + self.rnnU_fw(x[i, :, :]))  # bs, num_hidden
                h1_after_expand_fw = self.expand_dims(h1_fw, 0)
                output_fw = self.concat((output_fw, h1_after_expand_fw))  # i+1, bs, num_hidden
            output_fw = self.cast(output_fw, mstype.float16)  # sl, bs, num_hidden

            h1_bw = self.cast(self.h1, mstype.float16)  # bs, num_hidden
            h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[F.shape(x)[0] - 1, :, :]))  # bs, num_hidden
            output_bw = self.expand_dims(h1_bw, 0)  # 1, bs, num_hidden

            for i in range(F.shape(x)[0] - 2, -1, -1):
                h1_bw = self.tanh(self.rnnW_bw(h1_bw) + self.rnnU_bw(x[i, :, :]))  # bs, num_hidden
                h1_after_expand_bw = self.expand_dims(h1_bw, 0)
                output_bw = self.concat((h1_after_expand_bw, output_bw))  # sl-i, bs, num_hidden
            output_bw = self.cast(output_bw, mstype.float16)  # sl, bs, num_hidden

        if self.cell == "gru":
            x = self.embedding(x)  # bs, sl, emb_size
            x = self.cast(x, mstype.float16)
            x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
            x = self.drop_out(x)  # sl, bs, emb_size

            h_fw = self.cast(self.h1, mstype.float16)

            h_x_fw = self.concat1((h_fw, x[0, :, :]))
            r_fw = self.sigmoid(self.rnnWr_fw(h_x_fw))
            z_fw = self.sigmoid(self.rnnWz_fw(h_x_fw))
            h_tilde_fw = self.tanh(self.rnnWh_fw(self.concat1((r_fw * h_fw, x[0, :, :]))))
            h_fw = (self.ones - z_fw) * h_fw + z_fw * h_tilde_fw
            output_fw = self.expand_dims(h_fw, 0)

            for i in range(1, F.shape(x)[0]):
                h_x_fw = self.concat1((h_fw, x[i, :, :]))
                r_fw = self.sigmoid(self.rnnWr_fw(h_x_fw))
                z_fw = self.sigmoid(self.rnnWz_fw(h_x_fw))
                h_tilde_fw = self.tanh(self.rnnWh_fw(self.concat1((r_fw * h_fw, x[i, :, :]))))
                h_fw = (self.ones - z_fw) * h_fw + z_fw * h_tilde_fw
                h_after_expand_fw = self.expand_dims(h_fw, 0)
                output_fw = self.concat((output_fw, h_after_expand_fw))
            output_fw = self.cast(output_fw, mstype.float16)

            h_bw = self.cast(self.h1, mstype.float16)  # bs, num_hidden

            h_x_bw = self.concat1((h_bw, x[F.shape(x)[0] - 1, :, :]))
            r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw))
            z_bw = self.sigmoid(self.rnnWz_bw(h_x_bw))
            h_tilde_bw = self.tanh(self.rnnWh_bw(self.concat1((r_bw * h_bw, x[F.shape(x)[0] - 1, :, :]))))
            h_bw = (self.ones - z_bw) * h_bw + z_bw * h_tilde_bw
            output_bw = self.expand_dims(h_bw, 0)

            for i in range(F.shape(x)[0] - 2, -1, -1):
                h_x_bw = self.concat1((h_bw, x[i, :, :]))
                r_bw = self.sigmoid(self.rnnWr_bw(h_x_bw))
                z_bw = self.sigmoid(self.rnnWz_bw(h_x_bw))
                h_tilde_bw = self.tanh(self.rnnWh_bw(self.concat1((r_bw * h_bw, x[i, :, :]))))
                h_bw = (self.ones - z_bw) * h_bw + z_bw * h_tilde_bw
                h_after_expand_bw = self.expand_dims(h_bw, 0)
                output_bw = self.concat((h_after_expand_bw, output_bw))
            output_bw = self.cast(output_bw, mstype.float16)

        if self.cell == "lstm":
            x = self.embedding(x)  # bs, sl, emb_size
            x = self.cast(x, mstype.float16)
            x = self.transpose(x, (1, 0, 2))  # sl, bs, emb_size
            x = self.drop_out(x)  # sl, bs, emb_size

            h1_fw_init = self.h1  # 1, bs, num_hidden
            c1_fw_init = self.c1  # 1, bs, num_hidden

            _, output_fw, _, _, _, _, _, _ = self.lstm(x, self.w1_fw, self.b1_fw, None, h1_fw_init, c1_fw_init)
            output_fw = self.cast(output_fw, mstype.float16)  # sl, bs, num_hidden

            h1_bw_init = self.h1  # 1, bs, num_hidden
            c1_bw_init = self.c1  # 1, bs, num_hidden
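            # Note: as in the original implementation, this second DynamicRNN
            # also consumes x in left-to-right order; the "backward" weights
            # are not applied to a reversed sequence.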
            _, output_bw, _, _, _, _, _, _ = self.lstm(x, self.w1_bw, self.b1_bw, None, h1_bw_init, c1_bw_init)
            output_bw = self.cast(output_bw, mstype.float16)  # sl, bs, num_hidden
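
        # TextRCNN combination: for each position t, concatenate the left
        # context c_left[t] (hidden state after reading tokens before t), the
        # word embedding, and the right context c_right[t] (hidden state after
        # reading tokens after t).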
        c_left = self.concat0((self.left_pad_tensor, output_fw[:F.shape(x)[0] - 1]))  # sl, bs, num_hidden
        c_right = self.concat0((output_bw[1:], self.right_pad_tensor))  # sl, bs, num_hidden
        output = self.concat2((c_left, self.cast(x, mstype.float16), c_right))  # sl, bs, 2*num_hidden+emb_size
        output = self.cast(output, mstype.float16)

        output_flat = self.reshape(output, (F.shape(x)[0] * self.batch_size, 2 * self.num_hiddens + self.embed_size))
        output_dense = self.text_rep_dense(output_flat)  # sl*bs, num_hidden
        output_dense = self.tanh(output_dense)  # sl*bs, num_hidden
        output = self.reshape(output_dense, (F.shape(x)[0], self.batch_size, self.num_hiddens))  # sl, bs, num_hidden
        output = self.reduce_max(output, 0)  # max-pool over time: bs, num_hidden
        outputs = self.cast(self.mydense(output), mstype.float16)  # bs, num_classes
        return outputs
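

# ----------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original model zoo file). It
# assumes an Ascend backend (P.DynamicRNN is Ascend-only) and a MindSpore 1.x
# API matching the operators above; the vocabulary size, sequence length and
# random embedding table below are made-up placeholders.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    from mindspore import context

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

    vocab_size, seq_len, batch_size = 1000, 50, 8
    # Random table standing in for pre-trained word vectors (vocab, embed_size).
    table = Tensor(np.random.uniform(-1, 1, (vocab_size, 300)).astype(np.float32))
    net = textrcnn(weight=table, vocab_size=vocab_size, cell="gru", batch_size=batch_size)
    ids = Tensor(np.random.randint(0, vocab_size, (batch_size, seq_len)).astype(np.int32))
    logits = net(ids)
    print(logits.shape)  # expected: (batch_size, num_classes) == (8, 2)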