|
|
|
@@ -20,7 +20,6 @@ from .. import functional as F |
|
|
|
from ..primitive import constexpr |
|
|
|
from .multitype_ops import _constexpr_utils as const_utils |
|
|
|
from ...common import dtype as mstype |
|
|
|
from ...common.tensor import Tensor |
|
|
|
from ..._checkparam import Validator as validator |
|
|
|
from ..._checkparam import check_int_positive |
|
|
|
from ..._checkparam import Rel |
|
|
|
@@ -134,9 +133,7 @@ def multinomial(inputs, num_sample=None, replacement=True, seed=0): |
|
|
|
n_dist = 1 |
|
|
|
if len(shape(inputs)) > 1: |
|
|
|
n_dist = shape(inputs)[-2] |
|
|
|
a = Tensor(0.0, mstype.float32) |
|
|
|
b = Tensor(1.0, mstype.float32) |
|
|
|
random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,), a, b) |
|
|
|
random_uniform = P.UniformReal(seed=seed)((n_dist * num_sample,)) |
|
|
|
if n_dist != 1: |
|
|
|
random_uniform = reshape(random_uniform, (n_dist, num_sample)) |
|
|
|
vals = P.RealDiv()(P.Log()(random_uniform), inputs + 1e-6) |
|
|
|
|