|
|
|
@@ -143,7 +143,7 @@ class Rcnn(nn.Cell): |
|
|
|
|
|
|
|
if self.training: |
|
|
|
bbox_weights = self.cast(self.logicaland(self.greater(labels, 0), mask), mstype.int32) * labels |
|
|
|
labels = self.cast(self.onehot(labels, self.num_classes, self.on_value, self.off_value), self.ms_type) |
|
|
|
labels = self.onehot(labels, self.num_classes, self.on_value, self.off_value) |
|
|
|
bbox_targets = self.tile(self.expandims(bbox_targets, 1), (1, self.num_classes, 1)) |
|
|
|
|
|
|
|
loss, loss_cls, loss_reg, loss_print = self.loss(x_cls, x_reg, bbox_targets, bbox_weights, labels, mask) |
|
|
|
|