Browse Source

IGamma can be supported on GPU, add the GPU platform and fix some bugs

tags/v1.1.0
peixu_ren 5 years ago
parent
commit
db8ec0d281
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      mindspore/nn/layer/math.py

+ 3
- 3
mindspore/nn/layer/math.py View File

@@ -565,7 +565,7 @@ class IGamma(Cell):
Above :math:`Q(a, x)` is the upper regularized complete Gamma function.

Supported Platforms:
``Ascend``
``Ascend`` ``GPU``

Inputs:
- **a** (Tensor) - The input tensor. With float32 data type. `a` should have
@@ -619,8 +619,8 @@ class IGamma(Cell):
use_igammac = self.logicaland(self.greater(x, 1), self.greater(x, a))
ax = a * self.log(x) - x - self.lgamma(a)
para_shape = self.shape(ax)
broadcastto = P.BroadcastTo(para_shape)
if para_shape != ():
broadcastto = P.BroadcastTo(para_shape)
x = broadcastto(x)
a = broadcastto(a)
x_is_zero = self.equal(x, 0)
@@ -694,8 +694,8 @@ class LBeta(Cell):
_check_input_dtype("y", y_dtype, x_dtype, self.cls_name)
x_plus_y = x + y
para_shape = self.shape(x_plus_y)
broadcastto = P.BroadcastTo(para_shape)
if para_shape != ():
broadcastto = P.BroadcastTo(para_shape)
x = broadcastto(x)
y = broadcastto(y)
comp_less = self.less(x, y)


Loading…
Cancel
Save