Browse Source

GPU Lstm network

tags/v0.6.0-beta
wilfChen 5 years ago
parent
commit
fd2fa2bcba
1 changed files with 1 additions and 1 deletions
  1. +1
    -1
      model_zoo/lstm/src/lstm.py

+ 1
- 1
model_zoo/lstm/src/lstm.py View File

@@ -88,6 +88,6 @@ class SentimentNet(nn.Cell):
embeddings = self.trans(embeddings, self.perm) embeddings = self.trans(embeddings, self.perm)
output, _ = self.encoder(embeddings, (self.h, self.c)) output, _ = self.encoder(embeddings, (self.h, self.c))
# states[i] size(64,200) -> encoding.size(64,400) # states[i] size(64,200) -> encoding.size(64,400)
encoding = self.concat((output[0], output[-1]))
encoding = self.concat((output[0], output[499]))
outputs = self.decoder(encoding) outputs = self.decoder(encoding)
return outputs return outputs

Loading…
Cancel
Save