|
|
|
@@ -338,7 +338,8 @@ class MSSSIM(Cell): |
|
|
|
def construct(self, img1, img2): |
|
|
|
_check_input_4d(F.shape(img1), "img1", self.cls_name) |
|
|
|
_check_input_4d(F.shape(img2), "img2", self.cls_name) |
|
|
|
_check_input_dtype(F.dtype(img1), 'img1', mstype.number_type, self.cls_name) |
|
|
|
valid_type = [mstype.float64, mstype.float32, mstype.float16, mstype.uint8] |
|
|
|
_check_input_dtype(F.dtype(img1), 'img1', valid_type, self.cls_name) |
|
|
|
P.SameTypeShape()(img1, img2) |
|
|
|
dtype_max_val = _get_dtype_max(F.dtype(img1)) |
|
|
|
max_val = F.scalar_cast(self.max_val, F.dtype(img1)) |
|
|
|
|