diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py index 15a3100650..3a3024628c 100644 --- a/mindspore/nn/layer/math.py +++ b/mindspore/nn/layer/math.py @@ -565,7 +565,7 @@ class IGamma(Cell): Above :math:`Q(a, x)` is the upper regularized complete Gamma function. Supported Platforms: - ``Ascend`` + ``Ascend`` ``GPU`` Inputs: - **a** (Tensor) - The input tensor. With float32 data type. `a` should have @@ -619,8 +619,8 @@ class IGamma(Cell): use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a)) ax = a * self.log(x) - x - self.lgamma(a) para_shape = self.shape(ax) - broadcastto = P.BroadcastTo(para_shape) if para_shape != (): + broadcastto = P.BroadcastTo(para_shape) x = broadcastto(x) a = broadcastto(a) x_is_zero = self.equal(x, 0) @@ -694,8 +694,8 @@ class LBeta(Cell): _check_input_dtype("y", y_dtype, x_dtype, self.cls_name) x_plus_y = x + y para_shape = self.shape(x_plus_y) - broadcastto = P.BroadcastTo(para_shape) if para_shape != (): + broadcastto = P.BroadcastTo(para_shape) x = broadcastto(x) y = broadcastto(y) comp_less = self.less(x, y)