From 7bda2afb4ceaeab4bbfc3810c30c6f712987a02e Mon Sep 17 00:00:00 2001 From: chenhaozhe Date: Wed, 6 Jan 2021 10:14:01 +0800 Subject: [PATCH] fix bugs of wrong height and width of ssd --- model_zoo/official/cv/ssd/src/anchor_generator.py | 2 +- model_zoo/official/cv/ssd/src/dataset.py | 2 +- model_zoo/official/cv/ssd/src/ssd.py | 9 +++------ 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/model_zoo/official/cv/ssd/src/anchor_generator.py b/model_zoo/official/cv/ssd/src/anchor_generator.py index 62e2676d16..9941032f3a 100644 --- a/model_zoo/official/cv/ssd/src/anchor_generator.py +++ b/model_zoo/official/cv/ssd/src/anchor_generator.py @@ -39,7 +39,7 @@ class GridAnchorGenerator: scales_grid = scales_grid.reshape([-1]) aspect_ratios_grid = aspect_ratios_grid.reshape([-1]) - feature_size = [self.image_shape[0] / step, self.image_shape[0] / step] + feature_size = [self.image_shape[0] / step, self.image_shape[1] / step] grid_height, grid_width = feature_size base_size = np.array([self.scale * step, self.scale * step]).astype(np.float32) diff --git a/model_zoo/official/cv/ssd/src/dataset.py b/model_zoo/official/cv/ssd/src/dataset.py index f70024e5f1..738bb39784 100644 --- a/model_zoo/official/cv/ssd/src/dataset.py +++ b/model_zoo/official/cv/ssd/src/dataset.py @@ -122,7 +122,7 @@ def preprocess_fn(img_id, image, box, is_training): def _data_aug(image, box, is_training, image_size=(300, 300)): """Data augmentation function.""" ih, iw, _ = image.shape - w, h = image_size + h, w = image_size if not is_training: return _infer_data(image, image_size) diff --git a/model_zoo/official/cv/ssd/src/ssd.py b/model_zoo/official/cv/ssd/src/ssd.py index 6e9d1df45d..d9e4b82ddf 100644 --- a/model_zoo/official/cv/ssd/src/ssd.py +++ b/model_zoo/official/cv/ssd/src/ssd.py @@ -356,19 +356,16 @@ class SsdMobilenetV1Fpn(nn.Cell): Examples:backbone SsdMobilenetV1Fpn(config, True). """ - def __init__(self, config, is_training=True): + def __init__(self, config): super(SsdMobilenetV1Fpn, self).__init__() self.multi_box = WeightSharedMultiBox(config) - self.is_training = is_training - if not is_training: - self.activation = P.Sigmoid() - + self.activation = P.Sigmoid() self.feature_extractor = mobilenet_v1_fpn(config) def construct(self, x): features = self.feature_extractor(x) pred_loc, pred_label = self.multi_box(features) - if not self.is_training: + if not self.training: pred_label = self.activation(pred_label) pred_loc = F.cast(pred_loc, mstype.float32) pred_label = F.cast(pred_label, mstype.float32)