| @@ -712,8 +712,8 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): | |||||
| Args: | Args: | ||||
| sliced_parameters (list[Parameter]): Parameter slices in order of rank_id. | sliced_parameters (list[Parameter]): Parameter slices in order of rank_id. | ||||
| strategy (dict): Parameter slice strategy. Default: None. | |||||
| If strategy is None, just merge parameter slices in 0 axis order. | |||||
| strategy (dict): Parameter slice strategy, the default is None. | |||||
| If strategy is None, just merge parameter slices in 0 axis order. | |||||
| - key (str): Parameter name. | - key (str): Parameter name. | ||||
| - value (<class 'node_strategy_pb2.ParallelLayouts'>): Slice strategy of this parameter. | - value (<class 'node_strategy_pb2.ParallelLayouts'>): Slice strategy of this parameter. | ||||
| @@ -728,11 +728,15 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): | |||||
| Examples: | Examples: | ||||
| >>> strategy = build_searched_strategy("./strategy_train.ckpt") | >>> strategy = build_searched_strategy("./strategy_train.ckpt") | ||||
| >>> sliced_parameters = [\ | |||||
| Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), "network.embedding_table"), \ | |||||
| Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), "network.embedding_table"), \ | |||||
| Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), "network.embedding_tabel"), \ | |||||
| Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table")] | |||||
| >>> sliced_parameters = [ | |||||
| >>> Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), | |||||
| >>> "network.embedding_table"), | |||||
| >>> Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), | |||||
| >>> "network.embedding_table"), | |||||
| >>> Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), | |||||
| >>> "network.embedding_tabel"), | |||||
| >>> Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), | |||||
| >>> "network.embedding_table")] | |||||
| >>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy) | >>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy) | ||||
| """ | """ | ||||
| if not isinstance(sliced_parameters, list): | if not isinstance(sliced_parameters, list): | ||||