|
|
|
@@ -565,7 +565,7 @@ class IGamma(Cell): |
|
|
|
Above :math:`Q(a, x)` is the upper regularized complete Gamma function. |
|
|
|
|
|
|
|
Supported Platforms: |
|
|
|
``Ascend`` |
|
|
|
``Ascend`` ``GPU`` |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **a** (Tensor) - The input tensor. With float32 data type. `a` should have |
|
|
|
@@ -619,8 +619,8 @@ class IGamma(Cell): |
|
|
|
use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a)) |
|
|
|
ax = a * self.log(x) - x - self.lgamma(a) |
|
|
|
para_shape = self.shape(ax) |
|
|
|
broadcastto = P.BroadcastTo(para_shape) |
|
|
|
if para_shape != (): |
|
|
|
broadcastto = P.BroadcastTo(para_shape) |
|
|
|
x = broadcastto(x) |
|
|
|
a = broadcastto(a) |
|
|
|
x_is_zero = self.equal(x, 0) |
|
|
|
@@ -694,8 +694,8 @@ class LBeta(Cell): |
|
|
|
_check_input_dtype("y", y_dtype, x_dtype, self.cls_name) |
|
|
|
x_plus_y = x + y |
|
|
|
para_shape = self.shape(x_plus_y) |
|
|
|
broadcastto = P.BroadcastTo(para_shape) |
|
|
|
if para_shape != (): |
|
|
|
broadcastto = P.BroadcastTo(para_shape) |
|
|
|
x = broadcastto(x) |
|
|
|
y = broadcastto(y) |
|
|
|
comp_less = self.less(x, y) |
|
|
|
|