Browse Source

move set_seed() out from construct() of Net

tags/v1.0.0
Yi Huaijie 5 years ago
parent
commit
b863324d90
4 changed files with 4 additions and 8 deletions
  1. +1
    -2
      tests/st/ops/ascend/test_compoite_random_ops/test_gamma.py
  2. +1
    -2
      tests/st/ops/ascend/test_compoite_random_ops/test_normal.py
  3. +1
    -2
      tests/st/ops/ascend/test_compoite_random_ops/test_poisson.py
  4. +1
    -2
      tests/st/ops/ascend/test_compoite_random_ops/test_uniform.py

+ 1
- 2
tests/st/ops/ascend/test_compoite_random_ops/test_gamma.py View File

@@ -22,7 +22,7 @@ from mindspore.ops import composite as C
from mindspore.common import set_seed

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
set_seed(20)

class Net(nn.Cell):
def __init__(self, shape, seed=0):
@@ -31,7 +31,6 @@ class Net(nn.Cell):
self.seed = seed

def construct(self, alpha, beta):
set_seed(20)
return C.gamma(self.shape, alpha, beta, self.seed)




+ 1
- 2
tests/st/ops/ascend/test_compoite_random_ops/test_normal.py View File

@@ -22,7 +22,7 @@ from mindspore.ops import composite as C
from mindspore.common import set_seed

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
set_seed(20)

class Net(nn.Cell):
def __init__(self, shape, seed=0):
@@ -31,7 +31,6 @@ class Net(nn.Cell):
self.seed = seed

def construct(self, mean, stddev):
set_seed(20)
return C.normal(self.shape, mean, stddev, self.seed)




+ 1
- 2
tests/st/ops/ascend/test_compoite_random_ops/test_poisson.py View File

@@ -22,7 +22,7 @@ from mindspore.ops import composite as C
from mindspore.common import set_seed

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
set_seed(20)

class Net(nn.Cell):
def __init__(self, shape, seed=0):
@@ -31,7 +31,6 @@ class Net(nn.Cell):
self.seed = seed

def construct(self, mean):
set_seed(20)
return C.poisson(self.shape, mean, self.seed)




+ 1
- 2
tests/st/ops/ascend/test_compoite_random_ops/test_uniform.py View File

@@ -22,7 +22,7 @@ from mindspore.ops import composite as C
from mindspore.common import set_seed

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
set_seed(20)

class Net(nn.Cell):
def __init__(self, shape, seed=0):
@@ -31,7 +31,6 @@ class Net(nn.Cell):
self.seed = seed

def construct(self, minval, maxval):
set_seed(20)
return C.uniform(self.shape, minval, maxval, self.seed)




Loading…
Cancel
Save