| @@ -240,7 +240,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N | |||||
| ValueError: Checkpoint file is incorrect. | ValueError: Checkpoint file is incorrect. | ||||
| Examples: | Examples: | ||||
| >>> ckpt_file_name = "./checkpoint/LeNet5-2_1875.ckpt" | |||||
| >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" | |||||
| >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") | >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") | ||||
| """ | """ | ||||
| if not isinstance(ckpt_file_name, str): | if not isinstance(ckpt_file_name, str): | ||||
| @@ -341,8 +341,9 @@ def load_param_into_net(net, parameter_dict, strict_load=False): | |||||
| TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary. | TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary. | ||||
| Examples: | Examples: | ||||
| >>> net = LeNet5() | |||||
| >>> param_dict = load_checkpoint("LeNet5-2_1875.ckpt") | |||||
| >>> net = Net() | |||||
| >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" | |||||
| >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") | |||||
| >>> load_param_into_net(net, param_dict) | >>> load_param_into_net(net, param_dict) | ||||
| """ | """ | ||||
| if not isinstance(net, nn.Cell): | if not isinstance(net, nn.Cell): | ||||
| @@ -783,9 +784,6 @@ def build_searched_strategy(strategy_filename): | |||||
| ValueError: Strategy file is incorrect. | ValueError: Strategy file is incorrect. | ||||
| TypeError: Strategy_filename is not str. | TypeError: Strategy_filename is not str. | ||||
| Examples: | |||||
| >>> strategy_filename = "./strategy_train.ckpt" | |||||
| >>> strategy = build_searched_strategy(strategy_filename) | |||||
| """ | """ | ||||
| if not isinstance(strategy_filename, str): | if not isinstance(strategy_filename, str): | ||||
| raise TypeError(f"The strategy_filename should be str, but got {type(strategy_filename)}.") | raise TypeError(f"The strategy_filename should be str, but got {type(strategy_filename)}.") | ||||
| @@ -836,17 +834,16 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): | |||||
| KeyError: The parameter name is not in keys of strategy. | KeyError: The parameter name is not in keys of strategy. | ||||
| Examples: | Examples: | ||||
| >>> strategy = build_searched_strategy("./strategy_train.ckpt") | |||||
| >>> sliced_parameters = [ | >>> sliced_parameters = [ | ||||
| >>> Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), | |||||
| >>> "network.embedding_table"), | |||||
| >>> Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), | |||||
| >>> "network.embedding_table"), | |||||
| >>> Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), | |||||
| >>> "network.embedding_table"), | |||||
| >>> Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), | |||||
| >>> "network.embedding_table")] | |||||
| >>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy) | |||||
| ... Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), | |||||
| ... "network.embedding_table"), | |||||
| ... Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), | |||||
| ... "network.embedding_table"), | |||||
| ... Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), | |||||
| ... "network.embedding_table"), | |||||
| ... Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), | |||||
| ... "network.embedding_table")] | |||||
| >>> merged_parameter = merge_sliced_parameter(sliced_parameters) | |||||
| """ | """ | ||||
| if not isinstance(sliced_parameters, list): | if not isinstance(sliced_parameters, list): | ||||
| raise TypeError(f"The sliced_parameters should be list, but got {type(sliced_parameters)}.") | raise TypeError(f"The sliced_parameters should be list, but got {type(sliced_parameters)}.") | ||||