|
|
|
@@ -49,25 +49,24 @@ class RpnRegClsBlock(nn.Cell): |
|
|
|
self.lstm_fc = nn.Dense(2*config.hidden_size, 512).to_float(mstype.float16) |
|
|
|
self.rpn_cls = nn.Dense(in_channels=512, out_channels=num_anchors * cls_out_channels).to_float(mstype.float16) |
|
|
|
self.rpn_reg = nn.Dense(in_channels=512, out_channels=num_anchors * 4).to_float(mstype.float16) |
|
|
|
self.shape1 = (config.num_step, config.rnn_batch_size, -1) |
|
|
|
self.shape2 = (-1, config.batch_size, config.rnn_batch_size, config.num_step) |
|
|
|
self.shape1 = (-1, config.num_step, config.rnn_batch_size) |
|
|
|
self.shape2 = (config.batch_size, -1, config.rnn_batch_size, config.num_step) |
|
|
|
self.transpose = P.Transpose() |
|
|
|
self.print = P.Print() |
|
|
|
self.dropout = nn.Dropout(0.8) |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x = self.reshape(x, self.shape) |
|
|
|
x = self.lstm_fc(x) |
|
|
|
x1 = self.rpn_cls(x) |
|
|
|
x1 = self.transpose(x1, (1, 0)) |
|
|
|
x1 = self.reshape(x1, self.shape1) |
|
|
|
x1 = self.transpose(x1, (2, 1, 0)) |
|
|
|
x1 = self.transpose(x1, (0, 2, 1)) |
|
|
|
x1 = self.reshape(x1, self.shape2) |
|
|
|
x1 = self.transpose(x1, (1, 0, 2, 3)) |
|
|
|
x2 = self.rpn_reg(x) |
|
|
|
x2 = self.transpose(x2, (1, 0)) |
|
|
|
x2 = self.reshape(x2, self.shape1) |
|
|
|
x2 = self.transpose(x2, (2, 1, 0)) |
|
|
|
x2 = self.transpose(x2, (0, 2, 1)) |
|
|
|
x2 = self.reshape(x2, self.shape2) |
|
|
|
x2 = self.transpose(x2, (1, 0, 2, 3)) |
|
|
|
return x1, x2 |
|
|
|
|
|
|
|
class RPN(nn.Cell): |
|
|
|
|