Browse Source

modify comment

tags/v1.3.0
changzherui 4 years ago
parent
commit
4bbf7682fd
7 changed files with 24 additions and 16 deletions
  1. +1
    -1
      mindspore/train/callback/_checkpoint.py
  2. +1
    -1
      mindspore/train/callback/_dataset_graph.py
  3. +1
    -1
      mindspore/train/callback/_loss_monitor.py
  4. +3
    -1
      mindspore/train/callback/_lr_scheduler_callback.py
  5. +6
    -4
      mindspore/train/callback/_time_monitor.py
  6. +4
    -1
      mindspore/train/model.py
  7. +8
    -7
      mindspore/train/serialization.py

+ 1
- 1
mindspore/train/callback/_checkpoint.py View File

@@ -88,7 +88,7 @@ class CheckpointConfig:
mode, currently supports 'AES-GCM' and 'AES-CBC'. Default: 'AES-GCM'.

Raises:
ValueError: If the input_param is None or 0.
ValueError: If input parameter is not the correct type.

Examples:
>>> class LeNet5(nn.Cell):


+ 1
- 1
mindspore/train/callback/_dataset_graph.py View File

@@ -27,7 +27,7 @@ class DatasetGraph:
packages dataset graph into binary data

Args:
dataset (MindData): refer to MindDataset
dataset (MindDataset): Refer to MindDataset.

Returns:
DatasetGraph, a object of lineage_pb2.DatasetGraph.


+ 1
- 1
mindspore/train/callback/_loss_monitor.py View File

@@ -33,7 +33,7 @@ class LossMonitor(Callback):
per_print_times (int): Print the loss each every time. Default: 1.

Raises:
ValueError: If print_step is not an integer or less than zero.
ValueError: If per_print_times is not an integer or less than zero.
"""

def __init__(self, per_print_times=1):


+ 3
- 1
mindspore/train/callback/_lr_scheduler_callback.py View File

@@ -17,11 +17,13 @@
import math
import numpy as np

from mindspore import log as logger
import mindspore.common.dtype as mstype
from mindspore.common.tensor import Tensor
from mindspore.train.callback._callback import Callback
from mindspore.ops import functional as F


class LearningRateScheduler(Callback):
"""
Change the learning_rate during training.
@@ -66,4 +68,4 @@ class LearningRateScheduler(Callback):
new_lr = self.learning_rate_function(lr, cb_params.cur_step_num)
if not math.isclose(lr, new_lr, rel_tol=1e-10):
F.assign(cb_params.optimizer.learning_rate, Tensor(new_lr, mstype.float32))
print(f'At step {cb_params.cur_step_num}, learning_rate change to {new_lr}')
logger.info(f'At step {cb_params.cur_step_num}, learning_rate change to {new_lr}')

+ 6
- 4
mindspore/train/callback/_time_monitor.py View File

@@ -16,7 +16,6 @@

import time

from mindspore import log as logger
from ._callback import Callback


@@ -25,12 +24,16 @@ class TimeMonitor(Callback):
Monitor the time in training.

Args:
data_size (int): Dataset size. Default: None.
data_size (int): How many steps to return time information default is dataset size. Default: None.

Raises:
ValueError: If data_size is not positive int.
"""

def __init__(self, data_size=None):
super(TimeMonitor, self).__init__()
self.data_size = data_size
self.epoch_time = time.time()

def epoch_begin(self, run_context):
self.epoch_time = time.time()
@@ -45,8 +48,7 @@ class TimeMonitor(Callback):
step_size = cb_params.batch_num

if not isinstance(step_size, int) or step_size < 1:
logger.error("data_size must be positive int.")
return
raise ValueError("data_size must be positive int.")

step_seconds = epoch_seconds / step_size
print("epoch time: {:5.3f} ms, per step time: {:5.3f} ms".format(epoch_seconds, step_seconds), flush=True)

+ 4
- 1
mindspore/train/model.py View File

@@ -804,7 +804,10 @@ class Model:
predict_data (Tensor): One tensor or multiple tensors of predict data.

Returns:
Dict, Parameter layout dictionary used for load distributed checkpoint
Dict, Parameter layout dictionary used for load distributed checkpoint.

Raises:
RuntimeError: If get_context is not GRAPH_MODE.

Examples:
>>> # This example should be run with multiple devices. Refer to the tutorial > Distributed Training on


+ 8
- 7
mindspore/train/serialization.py View File

@@ -286,8 +286,7 @@ def load(file_name):
"""
Load MindIR.

The returned object can be executed by a `GraphCell`. However, there are some limitations to the current use
of `GraphCell`, see class :class:`mindspore.nn.GraphCell` for more details.
The returned object can be executed by a `GraphCell`, see class :class:`mindspore.nn.GraphCell` for more details.

Args:
file_name (str): MindIR file name.
@@ -296,7 +295,8 @@ def load(file_name):
Object, a compiled graph that can executed by `GraphCell`.

Raises:
ValueError: MindIR file is incorrect.
ValueError: MindIR file name is incorrect.
RuntimeError: Failed to parse MindIR file.

Examples:
>>> import numpy as np
@@ -644,7 +644,7 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):

Args:
net (Cell): MindSpore network.
inputs (Tensor): Inputs of the `net`.
inputs (Tensor): Inputs of the `net`, if the network has multiple inputs, incoming tuple(Tensor).
file_name (str): File name of the model to be exported.
file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model.

@@ -825,6 +825,7 @@ def _mindir_save_together(net_dict, model):
return False
return True


def quant_mode_manage(func):
"""
Inherit the quant_mode in old version.
@@ -840,6 +841,7 @@ def quant_mode_manage(func):
return func(network, *inputs, file_format=file_format, **kwargs)
return warpper


@quant_mode_manage
def _quant_export(network, *inputs, file_format, **kwargs):
"""
@@ -1034,12 +1036,11 @@ def build_searched_strategy(strategy_filename):
strategy_filename (str): Name of strategy file.

Returns:
Dictionary, whose key is parameter name and value is slice strategy of this parameter.
Dict, whose key is parameter name and value is slice strategy of this parameter.

Raises:
ValueError: Strategy file is incorrect.
TypeError: Strategy_filename is not str.

"""
if not isinstance(strategy_filename, str):
raise TypeError(f"The strategy_filename should be str, but got {type(strategy_filename)}.")
@@ -1162,7 +1163,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
predict_strategy (dict): Strategy of predication process, whose key is parameter name, and value is a list or
a tuple that the first four elements are [dev_matrix, tensor_map, param_split_shape, field]. If None,
it means that the predication process just uses single device. Default: None.
train_strategy_filename (str): Train strategy file. Default: None.
train_strategy_filename (str): Train strategy proto file name. Default: None.
dec_key (Union[None, bytes]): Byte type key used for decryption. If the value is None, the decryption
is not required. Default: None.
dec_mode (str): This parameter is valid only when dec_key is not set to None. Specifies the decryption


Loading…
Cancel
Save