Merge pull request !4181 from zhouyaqiang0/mastertags/v0.7.0-beta
| @@ -381,6 +381,7 @@ class DeepLabV3(nn.Cell): | |||||
| self.concat = P.Concat(axis=2) | self.concat = P.Concat(axis=2) | ||||
| self.expand_dims = P.ExpandDims() | self.expand_dims = P.ExpandDims() | ||||
| self.reduce_mean = P.ReduceMean() | self.reduce_mean = P.ReduceMean() | ||||
| self.argmax = P.Argmax(axis=1) | |||||
| self.sample_common = P.ResizeBilinear((int(feature_shape[2]), | self.sample_common = P.ResizeBilinear((int(feature_shape[2]), | ||||
| int(feature_shape[3])), | int(feature_shape[3])), | ||||
| align_corners=True) | align_corners=True) | ||||
| @@ -419,6 +420,8 @@ class DeepLabV3(nn.Cell): | |||||
| logits_i = self.expand_dims(logits_i, 2) | logits_i = self.expand_dims(logits_i, 2) | ||||
| logits = self.concat((logits, logits_i)) | logits = self.concat((logits, logits_i)) | ||||
| logits = self.reduce_mean(logits, 2) | logits = self.reduce_mean(logits, 2) | ||||
| if not self.training: | |||||
| logits = self.argmax(logits) | |||||
| return logits | return logits | ||||
| @@ -42,6 +42,8 @@ class OhemLoss(nn.Cell): | |||||
| self.loss_weight = 1.0 | self.loss_weight = 1.0 | ||||
| def construct(self, logits, labels): | def construct(self, logits, labels): | ||||
| if not self.training: | |||||
| return 0 | |||||
| logits = self.transpose(logits, (0, 2, 3, 1)) | logits = self.transpose(logits, (0, 2, 3, 1)) | ||||
| logits = self.reshape(logits, (-1, self.num)) | logits = self.reshape(logits, (-1, self.num)) | ||||
| labels = F.cast(labels, mstype.int32) | labels = F.cast(labels, mstype.int32) | ||||
| @@ -50,10 +50,7 @@ class MiouPrecision(Metric): | |||||
| raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) | raise ValueError('Need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) | ||||
| predict_in = self._convert_data(inputs[0]) | predict_in = self._convert_data(inputs[0]) | ||||
| label_in = self._convert_data(inputs[1]) | label_in = self._convert_data(inputs[1]) | ||||
| if predict_in.shape[1] != self._num_class: | |||||
| raise ValueError('Class number not match, last input data contain {} classes, but current data contain {} ' | |||||
| 'classes'.format(self._num_class, predict_in.shape[1])) | |||||
| pred = np.argmax(predict_in, axis=1) | |||||
| pred = predict_in | |||||
| label = label_in | label = label_in | ||||
| if len(label.flatten()) != len(pred.flatten()): | if len(label.flatten()) != len(pred.flatten()): | ||||
| print('Skipping: len(gt) = {:d}, len(pred) = {:d}'.format(len(label.flatten()), len(pred.flatten()))) | print('Skipping: len(gt) = {:d}, len(pred) = {:d}'.format(len(label.flatten()), len(pred.flatten()))) | ||||