
resolve frontend layout

tags/v1.0.0
caozhou, 5 years ago
commit f4f0e9af1f
1 changed file with 11 additions and 7 deletions

mindspore/train/serialization.py (+11, -7)

@@ -712,8 +712,8 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
 
     Args:
         sliced_parameters (list[Parameter]): Parameter slices in order of rank_id.
-        strategy (dict): Parameter slice strategy. Default: None.
-            If strategy is None, just merge parameter slices in 0 axis order.
+        strategy (dict): Parameter slice strategy, the default is None.
+            If strategy is None, just merge parameter slices in 0 axis order.
 
             - key (str): Parameter name.
             - value (<class 'node_strategy_pb2.ParallelLayouts'>): Slice strategy of this parameter.
@@ -728,11 +728,15 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
 
     Examples:
         >>> strategy = build_searched_strategy("./strategy_train.ckpt")
-        >>> sliced_parameters = [\
-                Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), "network.embedding_table"), \
-                Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), "network.embedding_table"), \
-                Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), "network.embedding_tabel"), \
-                Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table")]
+        >>> sliced_parameters = [
+        >>>     Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
+        >>>               "network.embedding_table"),
+        >>>     Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])),
+        >>>               "network.embedding_table"),
+        >>>     Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])),
+        >>>               "network.embedding_tabel"),
+        >>>     Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])),
+        >>>               "network.embedding_table")]
         >>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
     """
     if not isinstance(sliced_parameters, list):


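For readers who want to try the reformatted example end to end, the following is a minimal, self-contained sketch of the strategy=None path described in the docstring above (with no strategy, slices are merged in rank_id order along axis 0). It is illustrative and not part of the commit: it assumes MindSpore 1.0 is installed, reuses the slice values and the "network.embedding_table" name from the diff, and gives all four slices the same name so they merge into one parameter.

# Illustrative sketch (not part of the commit), assuming MindSpore 1.0.
import numpy as np
from mindspore import Parameter, Tensor
from mindspore.train.serialization import merge_sliced_parameter

# Four parameter slices, one per rank, all carrying the same parameter name.
sliced_parameters = [
    Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), "network.embedding_table"),
    Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), "network.embedding_table"),
    Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), "network.embedding_table"),
    Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table"),
]

# strategy is omitted (None), so the slices are concatenated along axis 0
# in rank_id order, yielding a single merged parameter.
merged_parameter = merge_sliced_parameter(sliced_parameters)
print(merged_parameter)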