| @@ -136,7 +136,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs): | |||||
| keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. | keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. | ||||
| Only `cast_model_type` is `float16`, `keep_batchnorm_fp32` will take effect. | Only `cast_model_type` is `float16`, `keep_batchnorm_fp32` will take effect. | ||||
| loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else | loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else | ||||
| scale the loss by LossScaleManager. If set, overwrite the level setting. | |||||
| scale the loss by `LossScaleManager`. If set, overwrite the level setting. | |||||
| """ | """ | ||||
| validator.check_value_type('network', network, nn.Cell) | validator.check_value_type('network', network, nn.Cell) | ||||
| validator.check_value_type('optimizer', optimizer, nn.Optimizer) | validator.check_value_type('optimizer', optimizer, nn.Optimizer) | ||||
| @@ -53,17 +53,19 @@ def connect_network_with_dataset(network, dataset_helper): | |||||
| Args: | Args: | ||||
| network (Cell): The training network for dataset. | network (Cell): The training network for dataset. | ||||
| dataset_helper(DatasetHelper): A class to process the MindData dataset, it provides the type, shape and queue | |||||
| dataset_helper (DatasetHelper): A class to process the MindData dataset, it provides the type, shape and queue | |||||
| name of the dataset to wrap the `GetNext`. | name of the dataset to wrap the `GetNext`. | ||||
| Outputs: | |||||
| Returns: | |||||
| Cell, a new network wrapped with 'GetNext' in the case of running the task on Ascend in graph mode, otherwise | Cell, a new network wrapped with 'GetNext' in the case of running the task on Ascend in graph mode, otherwise | ||||
| it is the input network. | it is the input network. | ||||
| Examples: | Examples: | ||||
| >>> from mindspore import DatasetHelper | |||||
| >>> | |||||
| >>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset | >>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset | ||||
| >>> train_dataset = create_custom_dataset() | >>> train_dataset = create_custom_dataset() | ||||
| >>> dataset_helper = mindspore.DatasetHelper(train_dataset, dataset_sink_mode=True) | |||||
| >>> dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=True) | |||||
| >>> net = Net() | >>> net = Net() | ||||
| >>> net_with_get_next = connect_network_with_dataset(net, dataset_helper) | >>> net_with_get_next = connect_network_with_dataset(net, dataset_helper) | ||||
| """ | """ | ||||
| @@ -145,14 +147,18 @@ class DatasetHelper: | |||||
| The iteration of DatasetHelper will provide one epoch data. | The iteration of DatasetHelper will provide one epoch data. | ||||
| Args: | Args: | ||||
| dataset (DataSet): The training dataset iterator. | |||||
| dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True. | |||||
| dataset (Dataset): The training dataset iterator. | |||||
| dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data | |||||
| from host. Default: True. | |||||
| sink_size (int): Control the amount of data in each sink. | sink_size (int): Control the amount of data in each sink. | ||||
| If sink_size=-1, sink the complete dataset for each epoch. | |||||
| If sink_size>0, sink sink_size data for each epoch. Default: -1. | |||||
| If sink_size=-1, sink the complete dataset for each epoch. | |||||
| If sink_size>0, sink sink_size data for each epoch. | |||||
| Default: -1. | |||||
| epoch_num (int): Control the number of epoch data to send. Default: 1. | epoch_num (int): Control the number of epoch data to send. Default: 1. | ||||
| Examples: | Examples: | ||||
| >>> from mindspore import nn, DatasetHelper | |||||
| >>> | |||||
| >>> network = Net() | >>> network = Net() | ||||
| >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | >>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") | ||||
| >>> network = nn.WithLossCell(network, net_loss) | >>> network = nn.WithLossCell(network, net_loss) | ||||
| @@ -373,6 +379,7 @@ class _DatasetIterPSServer(_DatasetIter): | |||||
| self.op = op | self.op = op | ||||
| class _DatasetIterPSWork(_DatasetIter): | class _DatasetIterPSWork(_DatasetIter): | ||||
| """Iter for context on MS_WORKER""" | """Iter for context on MS_WORKER""" | ||||
| @@ -388,6 +395,7 @@ class _DatasetIterPSWork(_DatasetIter): | |||||
| self.op = op | self.op = op | ||||
| class _DatasetIterNormal: | class _DatasetIterNormal: | ||||
| """Iter for normal(non sink) mode, feed the data from host.""" | """Iter for normal(non sink) mode, feed the data from host.""" | ||||
| @@ -33,6 +33,7 @@ class LossScaleManager: | |||||
| def get_update_cell(self): | def get_update_cell(self): | ||||
| """Get the loss scaling update logic cell.""" | """Get the loss scaling update logic cell.""" | ||||
| class FixedLossScaleManager(LossScaleManager): | class FixedLossScaleManager(LossScaleManager): | ||||
| """ | """ | ||||
| Fixed loss-scale manager. | Fixed loss-scale manager. | ||||
| @@ -42,9 +43,12 @@ class FixedLossScaleManager(LossScaleManager): | |||||
| drop_overflow_update (bool): whether to execute optimizer if there is an overflow. Default: True. | drop_overflow_update (bool): whether to execute optimizer if there is an overflow. Default: True. | ||||
| Examples: | Examples: | ||||
| >>> from mindspore import Model, nn | |||||
| >>> from mindspore.train.loss_scale_manager import FixedLossScaleManager | |||||
| >>> | |||||
| >>> net = Net() | >>> net = Net() | ||||
| >>> loss_scale_manager = FixedLossScaleManager() | >>> loss_scale_manager = FixedLossScaleManager() | ||||
| >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
| >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
| >>> model = Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim) | >>> model = Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim) | ||||
| """ | """ | ||||
| def __init__(self, loss_scale=128.0, drop_overflow_update=True): | def __init__(self, loss_scale=128.0, drop_overflow_update=True): | ||||
| @@ -87,9 +91,12 @@ class DynamicLossScaleManager(LossScaleManager): | |||||
| scale_window (int): Maximum continuous normal steps when there is no overflow. Default: 2000. | scale_window (int): Maximum continuous normal steps when there is no overflow. Default: 2000. | ||||
| Examples: | Examples: | ||||
| >>> from mindspore import Model, nn | |||||
| >>> from mindspore.train.loss_scale_manager import DynamicLossScaleManager | |||||
| >>> | |||||
| >>> net = Net() | >>> net = Net() | ||||
| >>> loss_scale_manager = DynamicLossScaleManager() | >>> loss_scale_manager = DynamicLossScaleManager() | ||||
| >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
| >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
| >>> model = Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim) | >>> model = Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim) | ||||
| """ | """ | ||||
| def __init__(self, | def __init__(self, | ||||
| @@ -76,7 +76,6 @@ class Model: | |||||
| to other metric. Default: None. | to other metric. Default: None. | ||||
| amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed | amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed | ||||
| precision training. Supports ["O0", "O2", "O3", "auto"]. Default: "O0". | precision training. Supports ["O0", "O2", "O3", "auto"]. Default: "O0". | ||||
| - O0: Do not change. | - O0: Do not change. | ||||
| - O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale. | - O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale. | ||||
| - O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'. | - O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'. | ||||
| @@ -93,6 +92,8 @@ class Model: | |||||
| will be overwritten. Default: True. | will be overwritten. Default: True. | ||||
| Examples: | Examples: | ||||
| >>> from mindspore import Model, nn | |||||
| >>> | |||||
| >>> class Net(nn.Cell): | >>> class Net(nn.Cell): | ||||
| ... def __init__(self, num_class=10, num_channel=1): | ... def __init__(self, num_class=10, num_channel=1): | ||||
| ... super(Net, self).__init__() | ... super(Net, self).__init__() | ||||
| @@ -118,7 +119,8 @@ class Model: | |||||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | >>> loss = nn.SoftmaxCrossEntropyWithLogits() | ||||
| >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | ||||
| >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) | >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None) | ||||
| >>> # For details about how to build the dataset, please refer to the tutorial document on the official website. | |||||
| >>> # For details about how to build the dataset, please refer to the tutorial | |||||
| >>> # document on the official website. | |||||
| >>> dataset = create_custom_dataset() | >>> dataset = create_custom_dataset() | ||||
| >>> model.train(2, dataset) | >>> model.train(2, dataset) | ||||
| """ | """ | ||||
| @@ -565,23 +567,29 @@ class Model: | |||||
| returned and passed to the network. Otherwise, a tuple (data, label) should | returned and passed to the network. Otherwise, a tuple (data, label) should | ||||
| be returned. The data and label would be passed to the network and loss | be returned. The data and label would be passed to the network and loss | ||||
| function respectively. | function respectively. | ||||
| callbacks (list, object): List of callback objects or callback object, which should be executed | |||||
| while training. Default: None. | |||||
| callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object, | |||||
| which should be executed while training. | |||||
| Default: None. | |||||
| dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. | dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. | ||||
| Configure pynative mode or CPU, the training process will be performed with | Configure pynative mode or CPU, the training process will be performed with | ||||
| dataset not sink. | |||||
| dataset not sink. Default: True. | |||||
| sink_size (int): Control the amount of data in each sink. | sink_size (int): Control the amount of data in each sink. | ||||
| If sink_size = -1, sink the complete dataset for each epoch. | If sink_size = -1, sink the complete dataset for each epoch. | ||||
| If sink_size > 0, sink sink_size data for each epoch. | If sink_size > 0, sink sink_size data for each epoch. | ||||
| If dataset_sink_mode is False, set sink_size as invalid. Default: -1. | |||||
| If dataset_sink_mode is False, set sink_size as invalid. | |||||
| Default: -1. | |||||
| Examples: | Examples: | ||||
| >>> from mindspore import Model, nn | |||||
| >>> from mindspore.train.loss_scale_manager import FixedLossScaleManager | >>> from mindspore.train.loss_scale_manager import FixedLossScaleManager | ||||
| >>> | |||||
| >>> # For details about how to build the dataset, please refer to the tutorial | |||||
| >>> # document on the official website. | |||||
| >>> dataset = create_custom_dataset() | >>> dataset = create_custom_dataset() | ||||
| >>> net = Net() | >>> net = Net() | ||||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | >>> loss = nn.SoftmaxCrossEntropyWithLogits() | ||||
| >>> loss_scale_manager = FixedLossScaleManager() | >>> loss_scale_manager = FixedLossScaleManager() | ||||
| >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
| >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9) | |||||
| >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager) | >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager) | ||||
| >>> model.train(2, dataset) | >>> model.train(2, dataset) | ||||
| """ | """ | ||||
| @@ -690,13 +698,19 @@ class Model: | |||||
| Args: | Args: | ||||
| valid_dataset (Dataset): Dataset to evaluate the model. | valid_dataset (Dataset): Dataset to evaluate the model. | ||||
| callbacks (list): List of callback objects which should be executed while training. Default: None. | |||||
| dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True. | |||||
| callbacks (Optional[list(Callback)]): List of callback objects which should be executed | |||||
| while training. Default: None. | |||||
| dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. | |||||
| Default: True. | |||||
| Returns: | Returns: | ||||
| Dict, which returns the loss value and metrics values for the model in the test mode. | Dict, which returns the loss value and metrics values for the model in the test mode. | ||||
| Examples: | Examples: | ||||
| >>> from mindspore import Model, nn | |||||
| >>> | |||||
| >>> # For details about how to build the dataset, please refer to the tutorial | |||||
| >>> # document on the official website. | |||||
| >>> dataset = create_custom_dataset() | >>> dataset = create_custom_dataset() | ||||
| >>> net = Net() | >>> net = Net() | ||||
| >>> loss = nn.SoftmaxCrossEntropyWithLogits() | >>> loss = nn.SoftmaxCrossEntropyWithLogits() | ||||
| @@ -739,14 +753,17 @@ class Model: | |||||
| Batch data should be put together in one tensor. | Batch data should be put together in one tensor. | ||||
| Args: | Args: | ||||
| predict_data: The predict data, can be bool, int, float, str, None, tensor, | |||||
| predict_data (Tensor): The predict data, can be bool, int, float, str, None, tensor, | |||||
| or tuple, list and dict that store these types. | or tuple, list and dict that store these types. | ||||
| Returns: | Returns: | ||||
| Tensor, array(s) of predictions. | Tensor, array(s) of predictions. | ||||
| Examples: | Examples: | ||||
| >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), mindspore.float32) | |||||
| >>> import mindspore as ms | |||||
| >>> from mindspore import Model, Tensor | |||||
| >>> | |||||
| >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32) | |||||
| >>> model = Model(Net()) | >>> model = Model(Net()) | ||||
| >>> result = model.predict(input_data) | >>> result = model.predict(input_data) | ||||
| """ | """ | ||||
| @@ -771,12 +788,16 @@ class Model: | |||||
| predict_data (Tensor): One tensor or multiple tensors of predict data. | predict_data (Tensor): One tensor or multiple tensors of predict data. | ||||
| Returns: | Returns: | ||||
| parameter_layout_dict (dict): Parameter layout dictionary used for load distributed checkpoint | |||||
| Dict, Parameter layout dictionary used for load distributed checkpoint | |||||
| Examples: | Examples: | ||||
| >>> import numpy as np | |||||
| >>> import mindspore as ms | |||||
| >>> from mindspore import Model, context, Tensor | |||||
| >>> | |||||
| >>> context.set_context(mode=context.GRAPH_MODE) | >>> context.set_context(mode=context.GRAPH_MODE) | ||||
| >>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL) | >>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL) | ||||
| >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) | |||||
| >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), ms.float32) | |||||
| >>> model = Model(Net()) | >>> model = Model(Net()) | ||||
| >>> model.infer_predict_layout(input_data) | >>> model.infer_predict_layout(input_data) | ||||
| """ | """ | ||||
| @@ -802,4 +823,5 @@ class Model: | |||||
| if param.cache_enable: | if param.cache_enable: | ||||
| Tensor(param).flush_from_cache() | Tensor(param).flush_from_cache() | ||||
| __all__ = ["Model"] | __all__ = ["Model"] | ||||
| @@ -153,16 +153,17 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F | |||||
| Saves checkpoint info to a specified file. | Saves checkpoint info to a specified file. | ||||
| Args: | Args: | ||||
| save_obj (nn.Cell or list): The cell object or data list(each element is a dictionary, like | |||||
| [{"name": param_name, "data": param_data},...], the type of param_name would | |||||
| be string, and the type of param_data would be parameter or tensor). | |||||
| save_obj (Union[Cell, list]): The cell object or data list(each element is a dictionary, like | |||||
| [{"name": param_name, "data": param_data},...], the type of | |||||
| param_name would be string, and the type of param_data would | |||||
| be parameter or `Tensor`). | |||||
| ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten. | ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten. | ||||
| integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True | integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True | ||||
| async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False | async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False | ||||
| Raises: | Raises: | ||||
| TypeError: If the parameter save_obj is not nn.Cell or list type.And if the parameter integrated_save and | |||||
| async_save are not bool type. | |||||
| TypeError: If the parameter save_obj is not `nn.Cell` or list type. And if the parameter | |||||
| `integrated_save` and `async_save` are not bool type. | |||||
| """ | """ | ||||
| if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list): | if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list): | ||||
| @@ -247,6 +248,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N | |||||
| ValueError: Checkpoint file is incorrect. | ValueError: Checkpoint file is incorrect. | ||||
| Examples: | Examples: | ||||
| >>> from mindspore import load_checkpoint | |||||
| >>> | |||||
| >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" | >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" | ||||
| >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") | >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") | ||||
| """ | """ | ||||
| @@ -349,6 +352,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False): | |||||
| TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary. | TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary. | ||||
| Examples: | Examples: | ||||
| >>> from mindspore import load_checkpoint, load_param_into_net | |||||
| >>> | |||||
| >>> net = Net() | >>> net = Net() | ||||
| >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" | >>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt" | ||||
| >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") | >>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1") | ||||
| @@ -531,7 +536,6 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs): | |||||
| inputs (Tensor): Inputs of the `net`. | inputs (Tensor): Inputs of the `net`. | ||||
| file_name (str): File name of the model to be exported. | file_name (str): File name of the model to be exported. | ||||
| file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model. | file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model. | ||||
| - AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model. | - AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model. | ||||
| Recommended suffix for output file is '.air'. | Recommended suffix for output file is '.air'. | ||||
| - ONNX: Open Neural Network eXchange. An open format built to represent machine learning models. | - ONNX: Open Neural Network eXchange. An open format built to represent machine learning models. | ||||
| @@ -541,7 +545,6 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs): | |||||
| Recommended suffix for output file is '.mindir'. | Recommended suffix for output file is '.mindir'. | ||||
| kwargs (dict): Configuration options dictionary. | kwargs (dict): Configuration options dictionary. | ||||
| - quant_mode: The mode of quant. | - quant_mode: The mode of quant. | ||||
| - mean: Input data mean. Default: 127.5. | - mean: Input data mean. Default: 127.5. | ||||
| - std_dev: Input data variance. Default: 127.5. | - std_dev: Input data variance. Default: 127.5. | ||||
| @@ -928,11 +931,9 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): | |||||
| Args: | Args: | ||||
| sliced_parameters (list[Parameter]): Parameter slices in order of rank_id. | sliced_parameters (list[Parameter]): Parameter slices in order of rank_id. | ||||
| strategy (dict): Parameter slice strategy, the default is None. | |||||
| If strategy is None, just merge parameter slices in 0 axis order. | |||||
| - key (str): Parameter name. | |||||
| - value (<class 'node_strategy_pb2.ParallelLayouts'>): Slice strategy of this parameter. | |||||
| strategy (Optional[dict]): Parameter slice strategy, whose key is parameter name and | |||||
| value is slice strategy of this parameter. If strategy is None, just merge | |||||
| parameter slices in 0 axis order. Default: None. | |||||
| Returns: | Returns: | ||||
| Parameter, the merged parameter which has the whole data. | Parameter, the merged parameter which has the whole data. | ||||
| @@ -943,6 +944,9 @@ def merge_sliced_parameter(sliced_parameters, strategy=None): | |||||
| KeyError: The parameter name is not in keys of strategy. | KeyError: The parameter name is not in keys of strategy. | ||||
| Examples: | Examples: | ||||
| >>> from mindspore.common.parameter import Parameter | |||||
| >>> from mindspore.train import merge_sliced_parameter | |||||
| >>> | |||||
| >>> sliced_parameters = [ | >>> sliced_parameters = [ | ||||
| ... Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), | ... Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), | ||||
| ... "network.embedding_table"), | ... "network.embedding_table"), | ||||
| @@ -1010,10 +1014,13 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy= | |||||
| Args: | Args: | ||||
| network (Cell): Network for distributed predication. | network (Cell): Network for distributed predication. | ||||
| checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id. | |||||
| predict_strategy (dict): Strategy of predication process, whose key is parameter name, and value is a list or | |||||
| a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None, | |||||
| it means that the predication process just uses single device. Default: None. | |||||
| checkpoint_filenames (list(str)): The name of Checkpoint files | |||||
| in order of rank id. | |||||
| predict_strategy (Optional(dict)): Strategy of predication process, whose key | |||||
| is parameter name, and value is a list or a tuple that the first four | |||||
| elements are [dev_matrix, tensor_map, param_split_shape, field]. If None, | |||||
| it means that the predication process just uses single device. | |||||
| Default: None. | |||||
| Raises: | Raises: | ||||
| TypeError: The type of inputs do not match the requirements. | TypeError: The type of inputs do not match the requirements. | ||||