|
|
|
@@ -58,6 +58,7 @@ class ImageGradients(Cell): |
|
|
|
super(ImageGradients, self).__init__() |
|
|
|
|
|
|
|
def construct(self, images): |
|
|
|
_check_input_4d(F.shape(images), "images", self.cls_name) |
|
|
|
batch_size, depth, height, width = P.Shape()(images) |
|
|
|
dy = images[:, :, 1:, :] - images[:, :, :height - 1, :] |
|
|
|
dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0) |
|
|
|
@@ -151,8 +152,8 @@ class SSIM(Cell): |
|
|
|
self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size) |
|
|
|
|
|
|
|
def construct(self, img1, img2): |
|
|
|
_check_input_4d(F.shape(img1), "img1", "SSIM") |
|
|
|
_check_input_4d(F.shape(img2), "img2", "SSIM") |
|
|
|
_check_input_4d(F.shape(img1), "img1", self.cls_name) |
|
|
|
_check_input_4d(F.shape(img2), "img2", self.cls_name) |
|
|
|
P.SameTypeShape()(img1, img2) |
|
|
|
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) |
|
|
|
img1 = _convert_img_dtype_to_float32(img1, self.max_val) |
|
|
|
@@ -244,8 +245,8 @@ class PSNR(Cell): |
|
|
|
self.max_val = max_val |
|
|
|
|
|
|
|
def construct(self, img1, img2): |
|
|
|
_check_input_4d(F.shape(img1), "img1", "PSNR") |
|
|
|
_check_input_4d(F.shape(img2), "img2", "PSNR") |
|
|
|
_check_input_4d(F.shape(img1), "img1", self.cls_name) |
|
|
|
_check_input_4d(F.shape(img2), "img2", self.cls_name) |
|
|
|
P.SameTypeShape()(img1, img2) |
|
|
|
max_val = _convert_img_dtype_to_float32(self.max_val, self.max_val) |
|
|
|
img1 = _convert_img_dtype_to_float32(img1, self.max_val) |
|
|
|
|