| @@ -66,13 +66,19 @@ class ImageGradients(Cell): | |||||
| check = _check_input_4d(F.shape(images), "images", self.cls_name) | check = _check_input_4d(F.shape(images), "images", self.cls_name) | ||||
| images = F.depend(images, check) | images = F.depend(images, check) | ||||
| batch_size, depth, height, width = P.Shape()(images) | batch_size, depth, height, width = P.Shape()(images) | ||||
| dy = images[:, :, 1:, :] - images[:, :, :height - 1, :] | |||||
| dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0) | |||||
| dy = P.Concat(2)((dy, dy_last)) | |||||
| dx = images[:, :, :, 1:] - images[:, :, :, :width - 1] | |||||
| dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0) | |||||
| dx = P.Concat(3)((dx, dx_last)) | |||||
| if height == 1: | |||||
| dy = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0) | |||||
| else: | |||||
| dy = images[:, :, 1:, :] - images[:, :, :height - 1, :] | |||||
| dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0) | |||||
| dy = P.Concat(2)((dy, dy_last)) | |||||
| if width == 1: | |||||
| dx = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0) | |||||
| else: | |||||
| dx = images[:, :, :, 1:] - images[:, :, :, :width - 1] | |||||
| dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0) | |||||
| dx = P.Concat(3)((dx, dx_last)) | |||||
| return dy, dx | return dy, dx | ||||