From: @Somnus2020 Reviewed-by: @kingxian Signed-off-by: @kingxiantags/v1.1.0
| @@ -32,6 +32,15 @@ class PhiloxGenerator { | |||||
| counter_[3] = static_cast<uint32_t>(seed_ >> 32); | counter_[3] = static_cast<uint32_t>(seed_ >> 32); | ||||
| } | } | ||||
| explicit PhiloxGenerator(uint64_t seed_, uint64_t seed2_) { | |||||
| key_var_[0] = static_cast<uint32_t>(seed_); | |||||
| key_var_[1] = static_cast<uint32_t>(seed_ >> 32); | |||||
| counter_[0] = 0; | |||||
| counter_[1] = 0; | |||||
| counter_[2] = static_cast<uint32_t>(seed2_); | |||||
| counter_[3] = static_cast<uint32_t>(seed2_ >> 32); | |||||
| } | |||||
| ~PhiloxGenerator() = default; | ~PhiloxGenerator() = default; | ||||
| void Jump(); | void Jump(); | ||||
| @@ -20,7 +20,7 @@ | |||||
| #include "runtime/device/cpu/cpu_device_address.h" | #include "runtime/device/cpu/cpu_device_address.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed, | |||||
| bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed, int64_t seed2, | |||||
| const py::object &output_tensor) { | const py::object &output_tensor) { | ||||
| if (out_shape.size() == 0) { | if (out_shape.size() == 0) { | ||||
| std::cout << "output data shape is error" << std::endl; | std::cout << "output data shape is error" << std::endl; | ||||
| @@ -41,7 +41,8 @@ bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, | |||||
| } | } | ||||
| int64_t batchSize = total_count / thread_num; | int64_t batchSize = total_count / thread_num; | ||||
| std::vector<std::thread> threads(thread_num); | std::vector<std::thread> threads(thread_num); | ||||
| mindspore::PhiloxGenerator generator = mindspore::PhiloxGenerator(seed); | |||||
| seed = (seed == 0 && seed2 == 0) ? clock() : seed; | |||||
| mindspore::PhiloxGenerator generator = mindspore::PhiloxGenerator(seed, seed2); | |||||
| if (thread_num != 1) { | if (thread_num != 1) { | ||||
| for (uint32_t i = 0; i < thread_num - 1; i++) { | for (uint32_t i = 0; i < thread_num - 1; i++) { | ||||
| float *offset_ptr = start_ptr + batchSize * i; | float *offset_ptr = start_ptr + batchSize * i; | ||||
| @@ -85,7 +85,7 @@ bool FillRandoms(PhiloxGenerator generator, float *output, int64_t vet_size, int | |||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed, | |||||
| bool InitRandomNormal(float mean, float stddev, std::vector<int64_t> out_shape, int64_t seed, int64_t seed2, | |||||
| const py::object &output_tensor); | const py::object &output_tensor); | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -45,16 +45,13 @@ class Initializer: | |||||
| @property | @property | ||||
| def seed(self): | def seed(self): | ||||
| if self._seed is None: | if self._seed is None: | ||||
| seed_ = get_seed() if get_seed() is not None else 1 | |||||
| _, seed = _get_graph_seed(seed_, "init") | |||||
| seed, seed2 = _get_graph_seed(get_seed(), "init") | |||||
| else: | else: | ||||
| seed = self._seed | |||||
| return seed | |||||
| seed, seed2 = self._seed + 1, 0 | |||||
| return seed, seed2 | |||||
| @seed.setter | @seed.setter | ||||
| def seed(self, value): | def seed(self, value): | ||||
| if not isinstance(value, int): | |||||
| raise TypeError("'value' must be int type.") | |||||
| self._seed = value | self._seed = value | ||||
| def _initialize(self, *kwargs): | def _initialize(self, *kwargs): | ||||
| @@ -367,9 +364,9 @@ class Normal(Initializer): | |||||
| self.sigma = sigma | self.sigma = sigma | ||||
| def _initialize(self, arr): | def _initialize(self, arr): | ||||
| seed = self.seed | |||||
| seed, seed2 = self.seed | |||||
| output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32)) | output_tensor = Tensor(np.zeros(arr.shape, dtype=np.float32)) | ||||
| random_normal(0, self.sigma, arr.shape, seed, output_tensor) | |||||
| random_normal(0, self.sigma, arr.shape, seed, seed2, output_tensor) | |||||
| output_data = output_tensor.asnumpy() | output_data = output_tensor.asnumpy() | ||||
| output_data *= self.sigma | output_data *= self.sigma | ||||
| _assignment(arr, output_data) | _assignment(arr, output_data) | ||||
| @@ -18,6 +18,7 @@ import mindspore.dataset as de | |||||
| from mindspore._checkparam import Validator | from mindspore._checkparam import Validator | ||||
| # constants | # constants | ||||
| DEFAULT_GRAPH_SEED = 87654321 | |||||
| _MAXINT32 = 2**31 - 1 | _MAXINT32 = 2**31 - 1 | ||||
| keyConstant = [3528531795, 2654435769, 3449720151, 3144134277] | keyConstant = [3528531795, 2654435769, 3449720151, 3144134277] | ||||
| @@ -210,7 +211,9 @@ def _get_graph_seed(op_seed, kernel_name): | |||||
| >>> _get_graph_seed(seed, 'normal') | >>> _get_graph_seed(seed, 'normal') | ||||
| """ | """ | ||||
| global_seed = get_seed() | global_seed = get_seed() | ||||
| if global_seed is None: | |||||
| if global_seed == 0: | |||||
| global_seed = DEFAULT_GRAPH_SEED | |||||
| elif global_seed is None: | |||||
| global_seed = 0 | global_seed = 0 | ||||
| if op_seed is None: | if op_seed is None: | ||||
| op_seed = 0 | op_seed = 0 | ||||
| @@ -465,7 +465,7 @@ class MetaTensor(MetaTensor_): | |||||
| def __exit__(self, ptype, value, trace): | def __exit__(self, ptype, value, trace): | ||||
| if self.need_set_seed: | if self.need_set_seed: | ||||
| np.random.seed(self._np_seed) | np.random.seed(self._np_seed) | ||||
| self.init.seed = self.seed | |||||
| self.init.seed, _ = self.seed | |||||
| with seed_context(self.init): | with seed_context(self.init): | ||||
| self.init(arr) | self.init(arr) | ||||
| @@ -39,7 +39,7 @@ class WithBNNLossCell(Cell): | |||||
| Examples: | Examples: | ||||
| >>> net = Net() | >>> net = Net() | ||||
| >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||||
| >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False) | |||||
| >>> net_with_criterion_object = WithBNNLossCell(net, loss_fn) | >>> net_with_criterion_object = WithBNNLossCell(net, loss_fn) | ||||
| >>> net_with_criterion = net_with_criterion_object() | >>> net_with_criterion = net_with_criterion_object() | ||||
| >>> | >>> | ||||
| @@ -46,7 +46,7 @@ class WithLossCell(Cell): | |||||
| Examples: | Examples: | ||||
| >>> net = Net() | >>> net = Net() | ||||
| >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) | |||||
| >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False) | |||||
| >>> net_with_criterion = nn.WithLossCell(net, loss_fn) | >>> net_with_criterion = nn.WithLossCell(net, loss_fn) | ||||
| >>> | >>> | ||||
| >>> batch_size = 2 | >>> batch_size = 2 | ||||