From: @c_34 Reviewed-by: @wuxuejian,@zhouyaqiang0,@oacjiewen Signed-off-by: @oacjiewentags/v1.2.0-rc1
| @@ -39,7 +39,7 @@ class GridAnchorGenerator: | |||||
| scales_grid = scales_grid.reshape([-1]) | scales_grid = scales_grid.reshape([-1]) | ||||
| aspect_ratios_grid = aspect_ratios_grid.reshape([-1]) | aspect_ratios_grid = aspect_ratios_grid.reshape([-1]) | ||||
| feature_size = [self.image_shape[0] / step, self.image_shape[0] / step] | |||||
| feature_size = [self.image_shape[0] / step, self.image_shape[1] / step] | |||||
| grid_height, grid_width = feature_size | grid_height, grid_width = feature_size | ||||
| base_size = np.array([self.scale * step, self.scale * step]).astype(np.float32) | base_size = np.array([self.scale * step, self.scale * step]).astype(np.float32) | ||||
| @@ -122,7 +122,7 @@ def preprocess_fn(img_id, image, box, is_training): | |||||
| def _data_aug(image, box, is_training, image_size=(300, 300)): | def _data_aug(image, box, is_training, image_size=(300, 300)): | ||||
| """Data augmentation function.""" | """Data augmentation function.""" | ||||
| ih, iw, _ = image.shape | ih, iw, _ = image.shape | ||||
| w, h = image_size | |||||
| h, w = image_size | |||||
| if not is_training: | if not is_training: | ||||
| return _infer_data(image, image_size) | return _infer_data(image, image_size) | ||||
| @@ -356,19 +356,16 @@ class SsdMobilenetV1Fpn(nn.Cell): | |||||
| Examples:backbone | Examples:backbone | ||||
| SsdMobilenetV1Fpn(config, True). | SsdMobilenetV1Fpn(config, True). | ||||
| """ | """ | ||||
| def __init__(self, config, is_training=True): | |||||
| def __init__(self, config): | |||||
| super(SsdMobilenetV1Fpn, self).__init__() | super(SsdMobilenetV1Fpn, self).__init__() | ||||
| self.multi_box = WeightSharedMultiBox(config) | self.multi_box = WeightSharedMultiBox(config) | ||||
| self.is_training = is_training | |||||
| if not is_training: | |||||
| self.activation = P.Sigmoid() | |||||
| self.activation = P.Sigmoid() | |||||
| self.feature_extractor = mobilenet_v1_fpn(config) | self.feature_extractor = mobilenet_v1_fpn(config) | ||||
| def construct(self, x): | def construct(self, x): | ||||
| features = self.feature_extractor(x) | features = self.feature_extractor(x) | ||||
| pred_loc, pred_label = self.multi_box(features) | pred_loc, pred_label = self.multi_box(features) | ||||
| if not self.is_training: | |||||
| if not self.training: | |||||
| pred_label = self.activation(pred_label) | pred_label = self.activation(pred_label) | ||||
| pred_loc = F.cast(pred_loc, mstype.float32) | pred_loc = F.cast(pred_loc, mstype.float32) | ||||
| pred_label = F.cast(pred_label, mstype.float32) | pred_label = F.cast(pred_label, mstype.float32) | ||||