Browse Source

fix textcnn problem

tags/v1.1.0
Yanjun Peng 5 years ago
parent
commit
f76758cbe8
1 changed files with 6 additions and 15 deletions
  1. +6
    -15
      model_zoo/official/nlp/textcnn/src/textcnn.py

+ 6
- 15
model_zoo/official/nlp/textcnn/src/textcnn.py View File

@@ -14,7 +14,6 @@
# ============================================================================
"""TextCNN"""

import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import Tensor
@@ -89,17 +88,9 @@ class SoftmaxCrossEntropyExpand(Cell):

return loss


def _weight_variable(shape, factor=0.01):
init_value = np.random.randn(*shape).astype(np.float32) * factor
return Tensor(init_value)


def make_conv_layer(kernel_size):
weight_shape = (96, 1, *kernel_size)
weight = _weight_variable(weight_shape)
return nn.Conv2d(in_channels=1, out_channels=96, kernel_size=kernel_size, padding=1,
pad_mode="pad", weight_init=weight, has_bias=True)
pad_mode="pad", weight_init='uniform', has_bias=True)


class TextCNN(nn.Cell):
@@ -113,7 +104,7 @@ class TextCNN(nn.Cell):
self.num_classes = num_classes

self.unsqueeze = P.ExpandDims()
self.embedding = nn.Embedding(vocab_len, self.vec_length, embedding_table='normal')
self.embedding = nn.Embedding(vocab_len, self.vec_length, embedding_table='uniform')

self.slice = P.Slice()
self.layer1 = self.make_layer(kernel_height=3)
@@ -125,7 +116,7 @@ class TextCNN(nn.Cell):
self.fc = nn.Dense(96*3, self.num_classes)
self.drop = nn.Dropout(keep_prob=0.5)
self.print = P.Print()
self.reducemean = P.ReduceMax(keep_dims=False)
self.reducemax = P.ReduceMax(keep_dims=False)

def make_layer(self, kernel_height):
return nn.SequentialCell(
@@ -145,9 +136,9 @@ class TextCNN(nn.Cell):
x2 = self.layer2(x)
x3 = self.layer3(x)

x1 = self.reducemean(x1, (2, 3))
x2 = self.reducemean(x2, (2, 3))
x3 = self.reducemean(x3, (2, 3))
x1 = self.reducemax(x1, (2, 3))
x2 = self.reducemax(x2, (2, 3))
x3 = self.reducemax(x3, (2, 3))

x = self.concat((x1, x2, x3))
x = self.drop(x)


Loading…
Cancel
Save