From f526c76617efa91a98e1c79c6a37f24a29d647ce Mon Sep 17 00:00:00 2001
From: lizhenyu
Date: Fri, 11 Dec 2020 09:30:53 +0800
Subject: [PATCH] refine random seed reserve process

---
 mindspore/common/parameter.py | 6 +++---
 mindspore/common/seed.py      | 2 ++
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index 9529f18a0a..f27b1b7825 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -449,6 +449,9 @@ class Parameter(MetaTensor_):
             return self
         if self.inited_param is not None:
             return self.inited_param
+        if _is_role_worker() and self.cache_enable:
+            global_seed, op_seed = _get_global_and_op_seed()
+            _insert_weight_init_info(self.name, global_seed, op_seed)
         if layout is not None:
             if not isinstance(layout, tuple):
                 raise TypeError("The layout should be tuple! layout is {}.".format(layout))
@@ -463,9 +466,6 @@ class Parameter(MetaTensor_):
             else:
                 data = self.init_mode.to_tensor(slice_index, layout[2], layout[5])
         else:
-            if _is_role_worker() and self.cache_enable:
-                global_seed, op_seed = _get_global_and_op_seed()
-                _insert_weight_init_info(self.name, global_seed, op_seed)
             if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
                 if _is_role_worker() or _is_role_sched():
                     data = self.init_mode.to_tensor(0, [1])
diff --git a/mindspore/common/seed.py b/mindspore/common/seed.py
index ad3cd859e9..f2dd9c263a 100644
--- a/mindspore/common/seed.py
+++ b/mindspore/common/seed.py
@@ -203,6 +203,8 @@ def _get_global_and_op_seed():
         global_seed = DEFAULT_GRAPH_SEED
     elif global_seed is None:
         global_seed = 0
+    if op_seed is None:
+        op_seed = 0
     Validator.check_non_negative_int(op_seed, "seed", "init")
     temp_seed = _get_op_seed(op_seed, "init")
     seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed)
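
For reference, a minimal standalone sketch of the seed-fallback behaviour the seed.py hunk adds. normalize_seeds and the constant value below are illustrative stand-ins, not MindSpore's actual helpers; the real _get_global_and_op_seed() validates op_seed with Validator.check_non_negative_int and then truncates both seeds.

    # Illustrative sketch only: mirrors the fallback order after this patch,
    # where a missing op seed defaults to 0 instead of failing validation.
    DEFAULT_GRAPH_SEED = 87654321  # stand-in constant; the real value is defined in seed.py

    def normalize_seeds(global_seed, op_seed):
        """Apply the patched fallbacks: global 0 -> default, None -> 0; op None -> 0."""
        if global_seed == 0:
            global_seed = DEFAULT_GRAPH_SEED
        elif global_seed is None:
            global_seed = 0
        if op_seed is None:  # new fallback introduced by this patch
            op_seed = 0
        if not isinstance(op_seed, int) or op_seed < 0:
            raise ValueError("op seed must be a non-negative int")
        return global_seed, op_seed

    print(normalize_seeds(None, None))  # (0, 0) instead of a validation error on None

The parameter.py hunks make the companion change at the call site: the worker-side seed reservation for cache-enabled parameters now runs at the top of Parameter.init_data, before the layout branch, so the seeds are recorded whether or not a layout is passed.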