
Update code syntax and examples in train

commit c09827daab (tags/v1.2.0-rc1)
zhiqiang, 4 years ago
5 changed files with 83 additions and 39 deletions
1. mindspore/train/amp.py (+1, -1)
2. mindspore/train/dataset_helper.py (+15, -7)
3. mindspore/train/loss_scale_manager.py (+9, -2)
4. mindspore/train/model.py (+35, -13)
5. mindspore/train/serialization.py (+23, -16)

mindspore/train/amp.py (+1, -1)

@@ -136,7 +136,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting.
Only `cast_model_type` is `float16`, `keep_batchnorm_fp32` will take effect.
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
- scale the loss by LossScaleManager. If set, overwrite the level setting.
+ scale the loss by `LossScaleManager`. If set, overwrite the level setting.
"""
validator.check_value_type('network', network, nn.Cell)
validator.check_value_type('optimizer', optimizer, nn.Optimizer)

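As a reading aid, a minimal sketch of how the patched `build_train_network` is typically called. `Net()` is a user-defined placeholder, exactly as in the docstring examples, and the `mindspore.train.amp` import path is assumed from the location of the patched file:

    import mindspore.nn as nn
    from mindspore.train import amp

    net = Net()  # user-defined nn.Cell, a placeholder as in the docstrings
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)

    # level='O2' casts the network to float16 and keeps batchnorm in float32;
    # passing keep_batchnorm_fp32 or loss_scale_manager overwrites the level setting.
    train_net = amp.build_train_network(net, optim, loss_fn=loss, level='O2')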

mindspore/train/dataset_helper.py (+15, -7)

@@ -53,17 +53,19 @@ def connect_network_with_dataset(network, dataset_helper):

Args:
network (Cell): The training network for dataset.
- dataset_helper(DatasetHelper): A class to process the MindData dataset, it provides the type, shape and queue
+ dataset_helper (DatasetHelper): A class to process the MindData dataset, it provides the type, shape and queue
name of the dataset to wrap the `GetNext`.

- Outputs:
+ Returns:
Cell, a new network wrapped with 'GetNext' in the case of running the task on Ascend in graph mode, otherwise
it is the input network.

Examples:
+ >>> from mindspore import DatasetHelper
+ >>>
>>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset
>>> train_dataset = create_custom_dataset()
- >>> dataset_helper = mindspore.DatasetHelper(train_dataset, dataset_sink_mode=True)
+ >>> dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=True)
>>> net = Net()
>>> net_with_get_next = connect_network_with_dataset(net, dataset_helper)
"""
@@ -145,14 +147,18 @@ class DatasetHelper:
The iteration of DatasetHelper will provide one epoch data.

Args:
- dataset (DataSet): The training dataset iterator.
- dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True.
+ dataset (Dataset): The training dataset iterator.
+ dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data
+ from host. Default: True.
sink_size (int): Control the amount of data in each sink.
- If sink_size=-1, sink the complete dataset for each epoch.
- If sink_size>0, sink sink_size data for each epoch. Default: -1.
+ If sink_size=-1, sink the complete dataset for each epoch.
+ If sink_size>0, sink sink_size data for each epoch.
+ Default: -1.
epoch_num (int): Control the number of epoch data to send. Default: 1.

Examples:
+ >>> from mindspore import nn, DatasetHelper
+ >>>
>>> network = Net()
>>> net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
>>> network = nn.WithLossCell(network, net_loss)
@@ -373,6 +379,7 @@ class _DatasetIterPSServer(_DatasetIter):

self.op = op

+
class _DatasetIterPSWork(_DatasetIter):
"""Iter for context on MS_WORKER"""

@@ -388,6 +395,7 @@ class _DatasetIterPSWork(_DatasetIter):

self.op = op

+
class _DatasetIterNormal:
"""Iter for normal(non sink) mode, feed the data from host."""


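A short sketch tying the two patched APIs together, following the updated docstring examples. `create_custom_dataset()` and `Net()` remain user-defined placeholders, and the `connect_network_with_dataset` import path is assumed from the patched file:

    from mindspore import DatasetHelper
    from mindspore.train.dataset_helper import connect_network_with_dataset

    train_dataset = create_custom_dataset()  # placeholder, as in the docstrings

    # sink_size=-1 sinks the complete dataset each epoch; a positive value
    # sinks that many pieces of data instead. epoch_num controls how many
    # epochs of data to send.
    dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1)

    # On Ascend in graph mode this wraps the network with GetNext;
    # otherwise the input network is returned unchanged.
    net = connect_network_with_dataset(Net(), dataset_helper)

    # Each iteration of the helper provides one epoch of data.
    for inputs in dataset_helper:
        outputs = net(*inputs)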

mindspore/train/loss_scale_manager.py (+9, -2)

@@ -33,6 +33,7 @@ class LossScaleManager:
def get_update_cell(self):
"""Get the loss scaling update logic cell."""

+
class FixedLossScaleManager(LossScaleManager):
"""
Fixed loss-scale manager.
@@ -42,9 +43,12 @@ class FixedLossScaleManager(LossScaleManager):
drop_overflow_update (bool): whether to execute optimizer if there is an overflow. Default: True.

Examples:
+ >>> from mindspore import Model, nn
+ >>> from mindspore.train.loss_scale_manager import FixedLossScaleManager
+ >>>
>>> net = Net()
>>> loss_scale_manager = FixedLossScaleManager()
- >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+ >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
"""
def __init__(self, loss_scale=128.0, drop_overflow_update=True):
@@ -87,9 +91,12 @@ class DynamicLossScaleManager(LossScaleManager):
scale_window (int): Maximum continuous normal steps when there is no overflow. Default: 2000.

Examples:
+ >>> from mindspore import Model, nn
+ >>> from mindspore.train.loss_scale_manager import DynamicLossScaleManager
+ >>>
>>> net = Net()
>>> loss_scale_manager = DynamicLossScaleManager()
- >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+ >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_scale_manager=loss_scale_manager, optimizer=optim)
"""
def __init__(self,

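The difference between the two managers, sketched with the parameters named in the hunks above (`Net()` is a placeholder, as throughout):

    from mindspore import Model, nn
    from mindspore.train.loss_scale_manager import FixedLossScaleManager, DynamicLossScaleManager

    net = Net()  # placeholder, as above
    optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)

    # Fixed: the scale never moves; drop_overflow_update=True skips the
    # optimizer step whenever an overflow is detected.
    fixed = FixedLossScaleManager(loss_scale=128.0, drop_overflow_update=True)

    # Dynamic: the scale shrinks on overflow and grows again after
    # scale_window consecutive overflow-free steps (default 2000).
    dynamic = DynamicLossScaleManager(scale_window=2000)

    model = Model(net, loss_scale_manager=dynamic, optimizer=optim)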

mindspore/train/model.py (+35, -13)

@@ -76,7 +76,6 @@ class Model:
to other metric. Default: None.
amp_level (str): Option for argument `level` in `mindspore.amp.build_train_network`, level for mixed
precision training. Supports ["O0", "O2", "O3", "auto"]. Default: "O0".
-
- O0: Do not change.
- O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
@@ -93,6 +92,8 @@ class Model:
will be overwritten. Default: True.

Examples:
+ >>> from mindspore import Model, nn
+ >>>
>>> class Net(nn.Cell):
... def __init__(self, num_class=10, num_channel=1):
... super(Net, self).__init__()
@@ -118,7 +119,8 @@ class Model:
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
- >>> # For details about how to build the dataset, please refer to the tutorial document on the official website.
+ >>> # For details about how to build the dataset, please refer to the tutorial
+ >>> # document on the official website.
>>> dataset = create_custom_dataset()
>>> model.train(2, dataset)
"""
@@ -565,23 +567,29 @@ class Model:
returned and passed to the network. Otherwise, a tuple (data, label) should
be returned. The data and label would be passed to the network and loss
function respectively.
- callbacks (list, object): List of callback objects or callback object, which should be executed
- while training. Default: None.
+ callbacks (Optional[list[Callback], Callback]): List of callback objects or callback object,
+ which should be executed while training.
+ Default: None.
- dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
+ dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
Configure pynative mode or CPU, the training process will be performed with
- dataset not sink.
+ dataset not sink. Default: True.
sink_size (int): Control the amount of data in each sink.
If sink_size = -1, sink the complete dataset for each epoch.
If sink_size > 0, sink sink_size data for each epoch.
- If dataset_sink_mode is False, set sink_size as invalid. Default: -1.
+ If dataset_sink_mode is False, set sink_size as invalid.
+ Default: -1.

Examples:
+ >>> from mindspore import Model, nn
+ >>> from mindspore.train.loss_scale_manager import FixedLossScaleManager
+ >>>
+ >>> # For details about how to build the dataset, please refer to the tutorial
+ >>> # document on the official website.
>>> dataset = create_custom_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
>>> loss_scale_manager = FixedLossScaleManager()
- >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+ >>> optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
>>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
>>> model.train(2, dataset)
"""
@@ -690,13 +698,19 @@ class Model:

Args:
valid_dataset (Dataset): Dataset to evaluate the model.
- callbacks (list): List of callback objects which should be executed while training. Default: None.
- dataset_sink_mode (bool): Determines whether to pass the data through dataset channel. Default: True.
+ callbacks (Optional[list(Callback)]): List of callback objects which should be executed
+ while training. Default: None.
+ dataset_sink_mode (bool): Determines whether to pass the data through dataset channel.
+ Default: True.

Returns:
Dict, which returns the loss value and metrics values for the model in the test mode.

Examples:
+ >>> from mindspore import Model, nn
+ >>>
+ >>> # For details about how to build the dataset, please refer to the tutorial
+ >>> # document on the official website.
>>> dataset = create_custom_dataset()
>>> net = Net()
>>> loss = nn.SoftmaxCrossEntropyWithLogits()
@@ -739,14 +753,17 @@ class Model:
Batch data should be put together in one tensor.

Args:
- predict_data: The predict data, can be bool, int, float, str, None, tensor,
+ predict_data (Tensor): The predict data, can be bool, int, float, str, None, tensor,
or tuple, list and dict that store these types.

Returns:
Tensor, array(s) of predictions.

Examples:
- >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), mindspore.float32)
+ >>> import mindspore as ms
+ >>> from mindspore import Model, Tensor
+ >>>
+ >>> input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
>>> model = Model(Net())
>>> result = model.predict(input_data)
"""
@@ -771,12 +788,16 @@ class Model:
predict_data (Tensor): One tensor or multiple tensors of predict data.

Returns:
- parameter_layout_dict (dict): Parameter layout dictionary used for load distributed checkpoint
+ Dict, Parameter layout dictionary used for load distributed checkpoint

Examples:
+ >>> import numpy as np
+ >>> import mindspore as ms
+ >>> from mindspore import Model, context, Tensor
+ >>>
>>> context.set_context(mode=context.GRAPH_MODE)
>>> context.set_auto_parallel_context(full_batch=True, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
- >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
+ >>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), ms.float32)
>>> model = Model(Net())
>>> model.infer_predict_layout(input_data)
"""
@@ -802,4 +823,5 @@ class Model:
if param.cache_enable:
Tensor(param).flush_from_cache()

+
__all__ = ["Model"]

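Putting the patched docstrings together, a minimal end-to-end sketch of the `Model` workflow this file documents, drawn from the examples shown above (`Net()` and `create_custom_dataset()` stay user-defined placeholders):

    import numpy as np
    import mindspore as ms
    from mindspore import Model, Tensor, nn
    from mindspore.train.loss_scale_manager import FixedLossScaleManager

    net = Net()  # placeholder, as in the docstrings
    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    optim = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
    model = Model(net, loss_fn=loss, optimizer=optim, metrics={"accuracy"},
                  amp_level="O0", loss_scale_manager=FixedLossScaleManager())

    # Train and evaluate with dataset sinking; sink_size=-1 sinks the whole
    # dataset per epoch.
    train_dataset = create_custom_dataset()  # placeholder, as in the docstrings
    model.train(2, train_dataset, dataset_sink_mode=True, sink_size=-1)
    metrics = model.eval(create_custom_dataset(), dataset_sink_mode=True)  # dict of metric values

    # Predict on a single batch, mirroring the updated example above.
    input_data = Tensor(np.random.randint(0, 255, [1, 1, 32, 32]), ms.float32)
    result = model.predict(input_data)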
mindspore/train/serialization.py (+23, -16)

@@ -153,16 +153,17 @@ def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=F
Saves checkpoint info to a specified file.

Args:
- save_obj (nn.Cell or list): The cell object or data list(each element is a dictionary, like
- [{"name": param_name, "data": param_data},...], the type of param_name would
- be string, and the type of param_data would be parameter or tensor).
+ save_obj (Union[Cell, list]): The cell object or data list(each element is a dictionary, like
+ [{"name": param_name, "data": param_data},...], the type of
+ param_name would be string, and the type of param_data would
+ be parameter or `Tensor`).
ckpt_file_name (str): Checkpoint file name. If the file name already exists, it will be overwritten.
integrated_save (bool): Whether to integrated save in automatic model parallel scene. Default: True
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False

Raises:
- TypeError: If the parameter save_obj is not nn.Cell or list type.And if the parameter integrated_save and
- async_save are not bool type.
+ TypeError: If the parameter save_obj is not `nn.Cell` or list type. And if the parameter
+ `integrated_save` and `async_save` are not bool type.
"""

if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
@@ -247,6 +248,8 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
ValueError: Checkpoint file is incorrect.

Examples:
+ >>> from mindspore import load_checkpoint
+ >>>
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
"""
@@ -349,6 +352,8 @@ def load_param_into_net(net, parameter_dict, strict_load=False):
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.

Examples:
+ >>> from mindspore import load_checkpoint, load_param_into_net
+ >>>
>>> net = Net()
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
@@ -531,7 +536,6 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
inputs (Tensor): Inputs of the `net`.
file_name (str): File name of the model to be exported.
file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model.
-
- AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model.
Recommended suffix for output file is '.air'.
- ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
@@ -541,7 +545,6 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
Recommended suffix for output file is '.mindir'.

kwargs (dict): Configuration options dictionary.
-
- quant_mode: The mode of quant.
- mean: Input data mean. Default: 127.5.
- std_dev: Input data variance. Default: 127.5.
@@ -928,11 +931,9 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):

Args:
sliced_parameters (list[Parameter]): Parameter slices in order of rank_id.
- strategy (dict): Parameter slice strategy, the default is None.
- If strategy is None, just merge parameter slices in 0 axis order.
-
- - key (str): Parameter name.
- - value (<class 'node_strategy_pb2.ParallelLayouts'>): Slice strategy of this parameter.
+ strategy (Optional[dict]): Parameter slice strategy, whose key is parameter name and
+ value is slice strategy of this parameter. If strategy is None, just merge
+ parameter slices in 0 axis order. Default: None.

Returns:
Parameter, the merged parameter which has the whole data.
@@ -943,6 +944,9 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
KeyError: The parameter name is not in keys of strategy.

Examples:
+ >>> from mindspore.common.parameter import Parameter
+ >>> from mindspore.train import merge_sliced_parameter
+ >>>
>>> sliced_parameters = [
... Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])),
... "network.embedding_table"),
@@ -1010,10 +1014,13 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=

Args:
network (Cell): Network for distributed predication.
- checkpoint_filenames (list[str]): The name of Checkpoint files in order of rank id.
- predict_strategy (dict): Strategy of predication process, whose key is parameter name, and value is a list or
- a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
- it means that the predication process just uses single device. Default: None.
+ checkpoint_filenames (list(str)): The name of Checkpoint files
+ in order of rank id.
+ predict_strategy (Optional(dict)): Strategy of predication process, whose key
+ is parameter name, and value is a list or a tuple that the first four
+ elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
+ it means that the predication process just uses single device.
+ Default: None.

Raises:
TypeError: The type of inputs do not match the requirements.

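Finally, a round-trip sketch of the checkpoint helpers touched above, reusing the hypothetical checkpoint path from the docstring examples and assuming `save_checkpoint` is exported at the package top level like the loaders shown above:

    from mindspore import save_checkpoint, load_checkpoint, load_param_into_net

    net = Net()  # placeholder, as above

    # integrated_save merges sliced parameters in the auto-parallel scene;
    # async_save writes the file without blocking the caller.
    save_checkpoint(net, "./checkpoint/LeNet5-1_32.ckpt", integrated_save=True, async_save=False)

    # Parameters whose names start with filter_prefix are not loaded.
    param_dict = load_checkpoint("./checkpoint/LeNet5-1_32.ckpt", filter_prefix="conv1")

    # strict_load=False tolerates parameters that are missing from the network.
    load_param_into_net(net, param_dict, strict_load=False)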
