|
|
|
@@ -1,4 +1,4 @@ |
|
|
|
# Copyright 2020 Huawei Technologies Co., Ltd |
|
|
|
# Copyright 2020-2022 Huawei Technologies Co., Ltd |
|
|
|
# |
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
|
# you may not use this file except in compliance with the License. |
|
|
|
@@ -62,7 +62,7 @@ class StandardNormal(PrimitiveWithInfer): |
|
|
|
def __init__(self, seed=0, seed2=0): |
|
|
|
"""Initialize StandardNormal""" |
|
|
|
self.init_prim_io_names(inputs=['shape'], outputs=['output']) |
|
|
|
self.add_prim_attr("_random_effect", True) |
|
|
|
self.add_prim_attr("side_effect_hidden", True) |
|
|
|
Validator.check_non_negative_int(seed, "seed", self.name) |
|
|
|
Validator.check_non_negative_int(seed2, "seed2", self.name) |
|
|
|
|
|
|
|
@@ -119,7 +119,7 @@ class StandardLaplace(PrimitiveWithInfer): |
|
|
|
def __init__(self, seed=0, seed2=0): |
|
|
|
"""Initialize StandardLaplace""" |
|
|
|
self.init_prim_io_names(inputs=['shape'], outputs=['output']) |
|
|
|
self.add_prim_attr("_random_effect", True) |
|
|
|
self.add_prim_attr("side_effect_hidden", True) |
|
|
|
Validator.check_value_type('seed', seed, [int], self.name) |
|
|
|
Validator.check_value_type('seed2', seed2, [int], self.name) |
|
|
|
|
|
|
|
@@ -196,7 +196,7 @@ class Gamma(PrimitiveWithInfer): |
|
|
|
def __init__(self, seed=0, seed2=0): |
|
|
|
"""Initialize Gamma""" |
|
|
|
self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output']) |
|
|
|
self.add_prim_attr("_random_effect", True) |
|
|
|
self.add_prim_attr("side_effect_hidden", True) |
|
|
|
Validator.check_non_negative_int(seed, "seed", self.name) |
|
|
|
Validator.check_non_negative_int(seed2, "seed2", self.name) |
|
|
|
|
|
|
|
@@ -262,7 +262,7 @@ class Poisson(PrimitiveWithInfer): |
|
|
|
def __init__(self, seed=0, seed2=0): |
|
|
|
"""Initialize Poisson""" |
|
|
|
self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output']) |
|
|
|
self.add_prim_attr("_random_effect", True) |
|
|
|
self.add_prim_attr("side_effect_hidden", True) |
|
|
|
Validator.check_non_negative_int(seed, "seed", self.name) |
|
|
|
Validator.check_non_negative_int(seed2, "seed2", self.name) |
|
|
|
|
|
|
|
@@ -334,7 +334,7 @@ class UniformInt(PrimitiveWithInfer): |
|
|
|
def __init__(self, seed=0, seed2=0): |
|
|
|
"""Initialize UniformInt""" |
|
|
|
self.init_prim_io_names(inputs=['shape', 'minval', 'maxval'], outputs=['output']) |
|
|
|
self.add_prim_attr("_random_effect", True) |
|
|
|
self.add_prim_attr("side_effect_hidden", True) |
|
|
|
Validator.check_non_negative_int(seed, "seed", self.name) |
|
|
|
Validator.check_non_negative_int(seed2, "seed2", self.name) |
|
|
|
|
|
|
|
@@ -451,7 +451,7 @@ class RandomChoiceWithMask(PrimitiveWithInfer): |
|
|
|
Validator.check_positive_int(count, "count", self.name) |
|
|
|
Validator.check_value_type('seed', seed, [int], self.name) |
|
|
|
Validator.check_value_type('seed2', seed2, [int], self.name) |
|
|
|
self.add_prim_attr("_random_effect", True) |
|
|
|
self.add_prim_attr("side_effect_hidden", True) |
|
|
|
|
|
|
|
def infer_shape(self, x_shape): |
|
|
|
Validator.check_int(len(x_shape), 1, Rel.GE, "input_x rank", self.name) |
|
|
|
@@ -513,7 +513,7 @@ class RandomCategorical(PrimitiveWithInfer): |
|
|
|
Validator.check_type_name("dtype", dtype, valid_values, self.name) |
|
|
|
self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'], |
|
|
|
outputs=['output']) |
|
|
|
self.add_prim_attr("_random_effect", True) |
|
|
|
self.add_prim_attr("side_effect_hidden", True) |
|
|
|
|
|
|
|
def __infer__(self, logits, num_samples, seed): |
|
|
|
logits_dtype = logits['dtype'] |
|
|
|
@@ -580,7 +580,7 @@ class Multinomial(PrimitiveWithInfer): |
|
|
|
Validator.check_non_negative_int(seed, "seed", self.name) |
|
|
|
Validator.check_non_negative_int(seed2, "seed2", self.name) |
|
|
|
self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output']) |
|
|
|
self.add_prim_attr("_random_effect", True) |
|
|
|
self.add_prim_attr("side_effect_hidden", True) |
|
|
|
|
|
|
|
def __infer__(self, inputs, num_samples): |
|
|
|
input_shape = inputs["shape"] |
|
|
|
|