Browse Source

!12517 enhance for ctpn performance

From: @qujianwei
Reviewed-by: @c_34,@oacjiewen
Signed-off-by: @c_34
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
264f265de8
6 changed files with 13 additions and 14 deletions
  1. +2
    -2
      model_zoo/official/cv/ctpn/src/CTPN/anchor_generator.py
  2. +1
    -1
      model_zoo/official/cv/ctpn/src/CTPN/bbox_assign_sample.py
  3. +1
    -1
      model_zoo/official/cv/ctpn/src/CTPN/proposal_generator.py
  4. +6
    -7
      model_zoo/official/cv/ctpn/src/CTPN/rpn.py
  5. +1
    -1
      model_zoo/official/cv/ctpn/src/dataset.py
  6. +2
    -2
      model_zoo/official/cv/ctpn/src/network_define.py

+ 2
- 2
model_zoo/official/cv/ctpn/src/CTPN/anchor_generator.py View File

@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn anchor generator."""
"""CTPN anchor generator."""
import numpy as np
class AnchorGenerator():
"""Anchor generator for FasterRcnn."""
"""Anchor generator for CTPN."""
def __init__(self, config):
"""Anchor generator init method."""
self.base_size = config.anchor_base


+ 1
- 1
model_zoo/official/cv/ctpn/src/CTPN/bbox_assign_sample.py View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn positive and negative sample screening for RPN."""
"""CTPN positive and negative sample screening for RPN."""

import numpy as np
import mindspore.nn as nn


+ 1
- 1
model_zoo/official/cv/ctpn/src/CTPN/proposal_generator.py View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn proposal generator."""
"""CTPN proposal generator."""

import numpy as np
import mindspore.nn as nn


+ 6
- 7
model_zoo/official/cv/ctpn/src/CTPN/rpn.py View File

@@ -49,25 +49,24 @@ class RpnRegClsBlock(nn.Cell):
self.lstm_fc = nn.Dense(2*config.hidden_size, 512).to_float(mstype.float16)
self.rpn_cls = nn.Dense(in_channels=512, out_channels=num_anchors * cls_out_channels).to_float(mstype.float16)
self.rpn_reg = nn.Dense(in_channels=512, out_channels=num_anchors * 4).to_float(mstype.float16)
self.shape1 = (config.num_step, config.rnn_batch_size, -1)
self.shape2 = (-1, config.batch_size, config.rnn_batch_size, config.num_step)
self.shape1 = (-1, config.num_step, config.rnn_batch_size)
self.shape2 = (config.batch_size, -1, config.rnn_batch_size, config.num_step)
self.transpose = P.Transpose()
self.print = P.Print()
self.dropout = nn.Dropout(0.8)

def construct(self, x):
x = self.reshape(x, self.shape)
x = self.lstm_fc(x)
x1 = self.rpn_cls(x)
x1 = self.transpose(x1, (1, 0))
x1 = self.reshape(x1, self.shape1)
x1 = self.transpose(x1, (2, 1, 0))
x1 = self.transpose(x1, (0, 2, 1))
x1 = self.reshape(x1, self.shape2)
x1 = self.transpose(x1, (1, 0, 2, 3))
x2 = self.rpn_reg(x)
x2 = self.transpose(x2, (1, 0))
x2 = self.reshape(x2, self.shape1)
x2 = self.transpose(x2, (2, 1, 0))
x2 = self.transpose(x2, (0, 2, 1))
x2 = self.reshape(x2, self.shape2)
x2 = self.transpose(x2, (1, 0, 2, 3))
return x1, x2

class RPN(nn.Cell):


+ 1
- 1
model_zoo/official/cv/ctpn/src/dataset.py View File

@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================

"""FasterRcnn dataset"""
"""CTPN dataset"""
from __future__ import division
import os
import numpy as np


+ 2
- 2
model_zoo/official/cv/ctpn/src/network_define.py View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FasterRcnn training network wrapper."""
"""CTPN training network wrapper."""

import time
import numpy as np
@@ -82,7 +82,7 @@ class LossCallBack(Callback):
loss_file.close()

class LossNet(nn.Cell):
"""FasterRcnn loss method"""
"""CTPN loss method"""
def construct(self, x1, x2, x3):
return x1



Loading…
Cancel
Save