|
|
|
@@ -646,10 +646,13 @@ class IGamma(Cell): |
|
|
|
x_dtype = self.dtype(x) |
|
|
|
_check_input_dtype("input_a", a_dtype, [mstype.float16, mstype.float32], self.cls_name) |
|
|
|
_check_input_dtype("input_x", x_dtype, a_dtype, self.cls_name) |
|
|
|
x_is_zero = self.equal(x, 0) |
|
|
|
domain_error = self.logicalor(self.less(x, 0), self.less(a, 0)) |
|
|
|
use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a)) |
|
|
|
ax = a * self.log(x) - x - self.lgamma(a) |
|
|
|
boradcastto = P.BroadcastTo(self.shape(ax)) |
|
|
|
a = boradcastto(a) |
|
|
|
x = boradcastto(x) |
|
|
|
x_is_zero = self.equal(x, 0) |
|
|
|
if a_dtype == mstype.float16: |
|
|
|
log_maxfloat = self.log_maxfloat16 |
|
|
|
else: |
|
|
|
|