Browse Source

!8760 Add broadcast for a and x support for IGamma

From: @peixu_ren
Reviewed-by: @liangchenghui,@zh_qh
Signed-off-by: @zh_qh
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
2424b8bd19
1 changed files with 4 additions and 1 deletions
  1. +4
    -1
      mindspore/nn/layer/math.py

+ 4
- 1
mindspore/nn/layer/math.py View File

@@ -646,10 +646,13 @@ class IGamma(Cell):
x_dtype = self.dtype(x)
_check_input_dtype("input_a", a_dtype, [mstype.float16, mstype.float32], self.cls_name)
_check_input_dtype("input_x", x_dtype, a_dtype, self.cls_name)
x_is_zero = self.equal(x, 0)
domain_error = self.logicalor(self.less(x, 0), self.less(a, 0))
use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))
ax = a * self.log(x) - x - self.lgamma(a)
boradcastto = P.BroadcastTo(self.shape(ax))
a = boradcastto(a)
x = boradcastto(x)
x_is_zero = self.equal(x, 0)
if a_dtype == mstype.float16:
log_maxfloat = self.log_maxfloat16
else:


Loading…
Cancel
Save