Browse Source

refine random seed reserve process

tags/v1.1.0
lizhenyu 5 years ago
parent
commit
f526c76617
2 changed files with 5 additions and 3 deletions
  1. +3
    -3
      mindspore/common/parameter.py
  2. +2
    -0
      mindspore/common/seed.py

+ 3
- 3
mindspore/common/parameter.py View File

@@ -449,6 +449,9 @@ class Parameter(MetaTensor_):
return self
if self.inited_param is not None:
return self.inited_param
if _is_role_worker() and self.cache_enable:
global_seed, op_seed = _get_global_and_op_seed()
_insert_weight_init_info(self.name, global_seed, op_seed)
if layout is not None:
if not isinstance(layout, tuple):
raise TypeError("The layout should be tuple! layout is {}.".format(layout))
@@ -463,9 +466,6 @@ class Parameter(MetaTensor_):
else:
data = self.init_mode.to_tensor(slice_index, layout[2], layout[5])
else:
if _is_role_worker() and self.cache_enable:
global_seed, op_seed = _get_global_and_op_seed()
_insert_weight_init_info(self.name, global_seed, op_seed)
if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, MetaTensor)):
if _is_role_worker() or _is_role_sched():
data = self.init_mode.to_tensor(0, [1])


+ 2
- 0
mindspore/common/seed.py View File

@@ -203,6 +203,8 @@ def _get_global_and_op_seed():
global_seed = DEFAULT_GRAPH_SEED
elif global_seed is None:
global_seed = 0
if op_seed is None:
op_seed = 0
Validator.check_non_negative_int(op_seed, "seed", "init")
temp_seed = _get_op_seed(op_seed, "init")
seeds = _truncate_seed(global_seed), _truncate_seed(temp_seed)


Loading…
Cancel
Save