|
|
|
@@ -356,19 +356,16 @@ class SsdMobilenetV1Fpn(nn.Cell): |
|
|
|
Examples:backbone |
|
|
|
SsdMobilenetV1Fpn(config, True). |
|
|
|
""" |
|
|
|
def __init__(self, config, is_training=True): |
|
|
|
def __init__(self, config): |
|
|
|
super(SsdMobilenetV1Fpn, self).__init__() |
|
|
|
self.multi_box = WeightSharedMultiBox(config) |
|
|
|
self.is_training = is_training |
|
|
|
if not is_training: |
|
|
|
self.activation = P.Sigmoid() |
|
|
|
|
|
|
|
self.activation = P.Sigmoid() |
|
|
|
self.feature_extractor = mobilenet_v1_fpn(config) |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
features = self.feature_extractor(x) |
|
|
|
pred_loc, pred_label = self.multi_box(features) |
|
|
|
if not self.is_training: |
|
|
|
if not self.training: |
|
|
|
pred_label = self.activation(pred_label) |
|
|
|
pred_loc = F.cast(pred_loc, mstype.float32) |
|
|
|
pred_label = F.cast(pred_label, mstype.float32) |
|
|
|
|