Browse Source

add interface doc for set_param_ps

tags/v1.1.0
lizhenyu 5 years ago
parent
commit
d265bc96a1
2 changed files with 15 additions and 1 deletions
  1. +11
    -0
      mindspore/common/parameter.py
  2. +4
    -1
      mindspore/nn/cell.py

+ 11
- 0
mindspore/common/parameter.py View File

@@ -167,6 +167,17 @@ class Parameter(MetaTensor_):
"""For parse check."""

def set_param_ps(self, init_in_server=False):
"""
Set whether the trainable parameter is updated by parameter server and whether the
trainable parameter is initialized on server.

Note:
It only works when a running task is in the parameter server mode.

Args:
init_in_server (bool): Whether trainable parameter updated by parameter server is
initialized on server. Default: False.
"""
if _is_role_worker() or _is_role_pserver() or _is_role_sched():
if init_in_server and (not self.name.endswith("embedding_table")):
raise RuntimeError("Can not initialize parameter '{}' in server, only parameters of "


+ 4
- 1
mindspore/nn/cell.py View File

@@ -1018,13 +1018,16 @@ class Cell(Cell_):

def set_param_ps(self, recurse=True, init_in_server=False):
"""
Set whether the trainable parameter is updated by parameter server.
Set whether the trainable parameters are updated by parameter server and whether the
trainable parameters are initialized on server.

Note:
It only works when a running task is in the parameter server mode.

Args:
recurse (bool): Whether sets the trainable parameters of subcells. Default: True.
init_in_server (bool): Whether trainable parameters updated by parameter server are
initialized on server. Default: False.
"""
params = self.trainable_params(recurse)
for param in params:


Loading…
Cancel
Save