From 7ae7505caab4fd88b74fde53464afca2cff7018a Mon Sep 17 00:00:00 2001 From: zhouyaqiang Date: Mon, 10 Aug 2020 09:39:52 +0800 Subject: [PATCH] move argmax from host to device --- model_zoo/official/cv/deeplabv3/src/deeplabv3.py | 3 +++ model_zoo/official/cv/deeplabv3/src/losses.py | 2 ++ model_zoo/official/cv/deeplabv3/src/miou_precision.py | 5 +---- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/model_zoo/official/cv/deeplabv3/src/deeplabv3.py b/model_zoo/official/cv/deeplabv3/src/deeplabv3.py index bbfc4dceb3..7b3c8eb53b 100644 --- a/model_zoo/official/cv/deeplabv3/src/deeplabv3.py +++ b/model_zoo/official/cv/deeplabv3/src/deeplabv3.py @@ -381,6 +381,7 @@ class DeepLabV3(nn.Cell): self.concat = P.Concat(axis=2) self.expand_dims = P.ExpandDims() self.reduce_mean = P.ReduceMean() + self.argmax = P.Argmax(axis=1) self.sample_common = P.ResizeBilinear((int(feature_shape[2]), int(feature_shape[3])), align_corners=True) @@ -419,6 +420,8 @@ class DeepLabV3(nn.Cell): logits_i = self.expand_dims(logits_i, 2) logits = self.concat((logits, logits_i)) logits = self.reduce_mean(logits, 2) + if not self.training: + logits = self.argmax(logits) return logits diff --git a/model_zoo/official/cv/deeplabv3/src/losses.py b/model_zoo/official/cv/deeplabv3/src/losses.py index af782c2de9..db45cbb6b6 100644 --- a/model_zoo/official/cv/deeplabv3/src/losses.py +++ b/model_zoo/official/cv/deeplabv3/src/losses.py @@ -42,6 +42,8 @@ class OhemLoss(nn.Cell): self.loss_weight = 1.0 def construct(self, logits, labels): + if not self.training: + return 0 logits = self.transpose(logits, (0, 2, 3, 1)) logits = self.reshape(logits, (-1, self.num)) labels = F.cast(labels, mstype.int32) diff --git a/model_zoo/official/cv/deeplabv3/src/miou_precision.py b/model_zoo/official/cv/deeplabv3/src/miou_precision.py index b73b3947d4..8b3e2f5d08 100644 --- a/model_zoo/official/cv/deeplabv3/src/miou_precision.py +++ b/model_zoo/official/cv/deeplabv3/src/miou_precision.py @@ -50,10 +50,7 @@ class MiouPrecision(Metric): raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) predict_in = self._convert_data(inputs[0]) label_in = self._convert_data(inputs[1]) - if predict_in.shape[1] != self._num_class: - raise ValueError('Class number not match, last input data contain {} classes, but current data contain {} ' - 'classes'.format(self._num_class, predict_in.shape[1])) - pred = np.argmax(predict_in, axis=1) + pred = predict_in label = label_in if len(label.flatten()) != len(pred.flatten()): print('Skipping: len(gt) = {:d}, len(pred) = {:d}'.format(len(label.flatten()), len(pred.flatten())))