Merge pull request !28334 from changzherui/mod_comment1feature/build-system-rewrite
| @@ -1,11 +1,11 @@ | |||
| .. py:class:: mindspore.train.callback.Callback | |||
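| Base class used to build callback functions. A callback function is a context manager that is called when the model runs. | |||
| You can use this mechanism to initialize and release resources, among other operations. | |||
| Base class used to build Callback functions. A Callback function is a context manager that is called when the model runs. | |||
| You can use this mechanism to perform custom operations. | |||
| Callback functions can perform operations within a step or epoch. | |||
| It holds model-related information, such as `network`, `train_network`, `epoch_num`, `batch_num`, `loss_fn`, `optimizer`, `parallel_mode`, `device_number`, `list_callback`, `cur_epoch_num`, `cur_step_num`, `dataset_sink_mode`, `net_outputs`, and so on. | |||
| Callback functions can perform operations before a step or epoch starts or after it ends. | |||
| To create a custom Callback, subclass the Callback base class and override the corresponding methods. For details on custom Callbacks, see | |||
| `Callback <https://www.mindspore.cn/docs/programming_guide/zh-CN/master/custom_debugging_info.html>`_. | |||
| **Examples:** | |||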
| @@ -63,53 +63,97 @@ | |||
| .. py:method:: append_dict | |||
| :property: | |||
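| Get the values of the dictionary added to the checkpoint. | |||
| Get the values of the dictionary to be additionally saved to the checkpoint. | |||
| **Returns:** | |||
| Dict: the values in the dictionary. | |||
| .. py:method:: async_save | |||
| :property: | |||
| Get whether the checkpoint is saved asynchronously. | |||
| **Returns:** | |||
| Bool: whether the checkpoint is saved asynchronously. | |||
| .. py:method:: enc_key | |||
| :property: | |||
| Get the encryption key. | |||
| **Returns:** | |||
| (None, bytes): the encryption key. | |||
| .. py:method:: enc_mode | |||
| :property: | |||
| Get the encryption mode. | |||
| **Returns:** | |||
| str: the encryption mode. | |||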
| .. py:method:: get_checkpoint_policy() | |||
| Get the checkpoint saving policy. | |||
| **Returns:** | |||
| Dict: the checkpoint saving policy. | |||
| .. py:method:: integrated_save | |||
| :property: | |||
| Get whether split Tensors are merged before saving. | |||
| **Returns:** | |||
| Bool: whether split Tensors are merged before saving. | |||
| .. py:method:: keep_checkpoint_max | |||
| :property: | |||
| Get the maximum number of checkpoint files that can be saved. | |||
| **Returns:** | |||
| Int: the maximum number of checkpoint files that can be saved. | |||
| .. py:method:: keep_checkpoint_per_n_minutes | |||
| :property: | |||
| Get the number of minutes between saved checkpoint files. | |||
| **Returns:** | |||
| Int: the number of minutes between saved checkpoint files. | |||
| .. py:method:: saved_network | |||
| :property: | |||
| Get the saved network. | |||
| Get the network to be saved. | |||
| **Returns:** | |||
| Cell: the network to be saved. | |||
| .. py:method:: save_checkpoint_seconds | |||
| :property: | |||
| Get the number of seconds between checkpoint saves.. | |||
| Get the number of seconds between checkpoint saves. | |||
| **Returns:** | |||
| Int: the number of seconds between checkpoint saves. | |||
| .. py:method:: save_checkpoint_steps | |||
| :property: | |||
| Get the number of steps between checkpoint saves. | |||
| **Returns:** | |||
| Int: the number of steps between checkpoint saves. | |||
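The read-only properties documented above, together with `get_checkpoint_policy()`, can be pictured with a small pure-Python sketch. `SketchCheckpointConfig` and its default values are hypothetical and cover only three of the documented settings; this is not MindSpore's `CheckpointConfig`.

```python
class SketchCheckpointConfig:
    """Hypothetical stand-in: stores settings privately, exposes them read-only."""

    def __init__(self, save_checkpoint_steps=1, save_checkpoint_seconds=0,
                 keep_checkpoint_max=5):
        # Defaults here are assumptions made for this sketch only.
        self._save_checkpoint_steps = save_checkpoint_steps
        self._save_checkpoint_seconds = save_checkpoint_seconds
        self._keep_checkpoint_max = keep_checkpoint_max

    @property
    def save_checkpoint_steps(self):
        """Number of steps between checkpoint saves."""
        return self._save_checkpoint_steps

    @property
    def save_checkpoint_seconds(self):
        """Number of seconds between checkpoint saves."""
        return self._save_checkpoint_seconds

    @property
    def keep_checkpoint_max(self):
        """Maximum number of checkpoint files that can be saved."""
        return self._keep_checkpoint_max

    def get_checkpoint_policy(self):
        """Collect the individual settings into one dictionary."""
        return {
            "save_checkpoint_steps": self.save_checkpoint_steps,
            "save_checkpoint_seconds": self.save_checkpoint_seconds,
            "keep_checkpoint_max": self.keep_checkpoint_max,
        }


config = SketchCheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=3)
policy = config.get_checkpoint_policy()
```

Because each property has no setter, assigning to `config.keep_checkpoint_max` raises `AttributeError`, which is why the settings can only be supplied at construction time in this sketch.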
| @@ -1,6 +1,6 @@ | |||
| .. py:class:: mindspore.train.callback.LearningRateScheduler(learning_rate_function) | |||
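| Change the learning rate during training. | |||
| Used to change the learning rate during training. | |||
| **Parameters:** | |||
| @@ -21,4 +21,4 @@ | |||
| **Parameters:** | |||
| - **run_context** (RunContext) - Contains some basic information about the model. | |||
| - **run_context** (RunContext) - Contains information about the model. | |||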
| @@ -3,7 +3,7 @@ | |||
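| Provides information about the model. | |||
| Provides information about the model within Model methods. | |||
| A callback function can stop the loop by calling the `request_stop()` method. | |||
| A callback function can call the `request_stop()` method to stop the iteration. | |||
| **Parameters:** | |||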
| @@ -11,7 +11,7 @@ | |||
| .. py:method:: get_stop_requested() | |||
| Get the flag whether to stop training. | |||
| Get the flag indicating whether to stop training. | |||
| **Returns:** | |||
| @@ -19,11 +19,11 @@ | |||
| .. py:method:: original_args() | |||
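| Get information about the model. | |||
| Get the object holding information about the model. | |||
| **Returns:** | |||
| dict, information about the model. | |||
| dict, an object containing information about the model. | |||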
| .. py:method:: request_stop() | |||
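How `request_stop()`, `get_stop_requested()` and `original_args()` fit together can be sketched in plain Python. Everything below is a hypothetical stand-in written for illustration (`SketchRunContext`, `StopAtLoss`, `run_loop`, the `step_end` hook name and the `loss` key are assumptions), not MindSpore's implementation.

```python
class SketchRunContext:
    """Hypothetical stand-in for RunContext: model info plus a stop flag."""

    def __init__(self, original_args):
        self._original_args = original_args
        self._stop_requested = False

    def original_args(self):
        """Return the object holding information about the model."""
        return self._original_args

    def request_stop(self):
        """Record that a callback asked for the iteration to stop."""
        self._stop_requested = True

    def get_stop_requested(self):
        """Return the flag indicating whether to stop training."""
        return self._stop_requested


class StopAtLoss:
    """Hypothetical callback: request a stop once the loss is below a threshold."""

    def __init__(self, threshold):
        self.threshold = threshold

    def step_end(self, run_context):
        info = run_context.original_args()
        if info["loss"] < self.threshold:
            run_context.request_stop()


def run_loop(losses, callback):
    """Toy loop: call the callback after every step and honor the stop flag."""
    info = {"cur_step_num": 0, "loss": None}
    ctx = SketchRunContext(info)
    for loss in losses:
        info["cur_step_num"] += 1
        info["loss"] = loss
        callback.step_end(ctx)
        if ctx.get_stop_requested():
            break
    return info["cur_step_num"]


# The third loss (0.2) is the first one below 0.3, so the loop ends there.
steps_run = run_loop([0.9, 0.5, 0.2, 0.1], StopAtLoss(threshold=0.3))
```

The callback never breaks the loop itself; it only sets the flag, and the loop decides when to check it. That separation is the point of routing the request through the context object.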
| @@ -5,9 +5,8 @@ mindspore.async_ckpt_thread_status | |||
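| Get the status of the asynchronous checkpoint-saving thread. | |||
| When saving a checkpoint asynchronously, you can use this function to get the thread status and make sure the checkpoint file has been fully written. | |||
| When saving a checkpoint asynchronously, determine whether the asynchronous thread has finished. | |||
| **Returns:** | |||
| True, the asynchronous checkpoint-saving thread is running. | |||
| False, the asynchronous checkpoint-saving thread is not running. | |||
| Bool: True, the asynchronous checkpoint-saving thread is running. False, the asynchronous checkpoint-saving thread is not running. | |||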
| @@ -5,22 +5,22 @@ mindspore.load | |||
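| Load a MindIR file. | |||
| The returned object can be executed by `GraphCell`. For more details, see class :class:`mindspore.nn.GraphCell` . | |||
| Returns an object that can be executed by `GraphCell`. For more details, see class :class:`mindspore.nn.GraphCell`. | |||
| **Parameters:** | |||
| - **file_name** (str) – MindIR file name. | |||
| - **file_name** (str) – Full path name of the MindIR file. | |||
| - **kwargs** (dict) – Dictionary of configuration options. | |||
| - **dec_key** (bytes) - Byte-type key used for decryption. The valid length is 16, 24, or 32. | |||
| - **dec_mode** - Specifies the decryption mode; takes effect when dec_key is set. Options: 'AES-GCM' | 'AES-CBC'. Default: “AES-GCM”. | |||
| - **dec_mode** (str) - Specifies the decryption mode; takes effect when dec_key is set. Options: 'AES-GCM' | 'AES-CBC'. Default: "AES-GCM". | |||
| **Returns:** | |||
| Object, a compiled graph that can be executed by `GraphCell`. | |||
| GraphCell, a compiled graph that can be executed by `GraphCell`. | |||
| **Raises:** | |||
| - **ValueError** – The MindIR file name is incorrect. | |||
| - **ValueError** – The MindIR file does not exist or `file_name` is not a string. | |||
| - **RuntimeError** - Failed to parse the MindIR file. | |||
| **Examples:** | |||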
| @@ -20,7 +20,7 @@ mindspore.load_checkpoint | |||
| **Raises:** | |||
| - **ValueError** – The checkpoint file format is correct. | |||
| - **ValueError** – The checkpoint file format is incorrect. | |||
| **Examples:** | |||
| @@ -3,7 +3,7 @@ mindspore.load_param_into_net | |||
| .. py:class:: mindspore.load_param_into_net(net, parameter_dict, strict_load=False) | |||
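| Load parameters into the network. | |||
| Load parameters into the network and return the list of parameters in the network that were not loaded. | |||
| **Parameters:** | |||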
| @@ -3,13 +3,11 @@ mindspore.parse_print | |||
| .. py:class:: mindspore.parse_print(print_file_name) | |||
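| Parse the saved data generated by mindspore.ops.Print. | |||
| Prints data to the screen. This can also be turned off by setting the `print_file_path` parameter in `context`, in which case the data is saved to the file specified by `print_file_path`. parse_print is used to parse the saved file. For more information, see :func:`mindspore.context.set_context` and :class:`mindspore.ops.Print` . | |||
| Parse the data file generated by mindspore.ops.Print. | |||
| **Parameters:** | |||
| **print_file_name** (str) – Name of the file that holds the saved print data. | |||
| **print_file_name** (str) – Name of the file to be parsed. | |||
| **Returns:** | |||
| @@ -17,12 +15,12 @@ mindspore.parse_print | |||
| **Raises:** | |||
| **ValueError** – The specified file name may be empty; make sure you enter the correct file name. | |||
| **ValueError** – The specified file does not exist or is empty. | |||
| **RuntimeError** - Failed to parse the file. | |||
| **Examples:** | |||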
| >>> import numpy as np | |||
| >>> import mindspore | |||
| >>> import mindspore.ops as ops | |||
| >>> from mindspore import nn | |||
| >>> from mindspore import Tensor, context | |||
| @@ -40,8 +38,10 @@ mindspore.parse_print | |||
| >>> input_pra = Tensor(x) | |||
| >>> net = PrintInputTensor() | |||
| >>> net(input_pra) | |||
| >>> import mindspore | |||
| >>> data = mindspore.parse_print('./log.data') | |||
| >>> print(data) | |||
| ['print:', Tensor(shape=[2, 4], dtype=Float32, value= | |||
| [[ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00, 4.00000000e+00], | |||
| [ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])] | |||
| [ 5.00000000e+00, 6.00000000e+00, 7.00000000e+00, 8.00000000e+00]])] | |||
| @@ -77,17 +77,13 @@ class Callback: | |||
| """ | |||
| Abstract base class used to build a callback class. Callbacks are context managers | |||
| which will be entered and exited when passing into the Model. | |||
| You can use this mechanism to initialize and release resources automatically. | |||
| You can use this mechanism to do some custom operations. | |||
| Callback function will execute some operations in the current step or epoch. | |||
| Callback functions can perform operations before and after a step or epoch. | |||
| To create a custom callback, subclass Callback and override the method associated | |||
| with the stage of interest. For details on custom Callbacks, please check | |||
| `Callback <https://www.mindspore.cn/docs/programming_guide/zh-CN/master/custom_debugging_info.html>`_. | |||
| It holds the information of the model. Such as `network`, `train_network`, `epoch_num`, `batch_num`, | |||
| `loss_fn`, `optimizer`, `parallel_mode`, `device_number`, `list_callback`, `cur_epoch_num`, | |||
| `cur_step_num`, `dataset_sink_mode`, `net_outputs` and so on. | |||
| Examples: | |||
| >>> from mindspore import Model, nn | |||
| >>> from mindspore.train.callback import Callback | |||
| @@ -196,7 +196,12 @@ class CheckpointConfig: | |||
| @property | |||
| def save_checkpoint_steps(self): | |||
| """Get the value of _save_checkpoint_steps.""" | |||
| """ | |||
| Get the number of steps between checkpoint saves. | |||
| Returns: | |||
| Int, the number of steps between checkpoint saves. | |||
| """ | |||
| return self._save_checkpoint_steps | |||
| @property | |||
| @@ -206,46 +211,91 @@ class CheckpointConfig: | |||
| @property | |||
| def keep_checkpoint_max(self): | |||
| """Get the value of _keep_checkpoint_max.""" | |||
| """ | |||
| Get the maximum number of checkpoint files that can be saved. | |||
| Returns: | |||
| Int, the maximum number of checkpoint files that can be saved. | |||
| """ | |||
| return self._keep_checkpoint_max | |||
| @property | |||
| def keep_checkpoint_per_n_minutes(self): | |||
| """Get the value of _keep_checkpoint_per_n_minutes.""" | |||
| """ | |||
| Get the number of minutes between saved checkpoint files. | |||
| Returns: | |||
| Int, the number of minutes between saved checkpoint files. | |||
| """ | |||
| return self._keep_checkpoint_per_n_minutes | |||
| @property | |||
| def integrated_save(self): | |||
| """Get the value of _integrated_save.""" | |||
| """ | |||
| Get whether the split Tensor is merged and saved in the automatic parallel scenario. | |||
| Returns: | |||
| Bool, whether the split Tensor is merged and saved in the automatic parallel scenario. | |||
| """ | |||
| return self._integrated_save | |||
| @property | |||
| def async_save(self): | |||
| """Get the value of _async_save.""" | |||
| """ | |||
| Get whether the checkpoint is saved to a file asynchronously. | |||
| Returns: | |||
| Bool, whether the checkpoint is saved to a file asynchronously. | |||
| """ | |||
| return self._async_save | |||
| @property | |||
| def saved_network(self): | |||
| """Get the value of _saved_network""" | |||
| """ | |||
| Get the network to be saved in the checkpoint file. | |||
| Returns: | |||
| Cell, the network to be saved in the checkpoint file. | |||
| """ | |||
| return self._saved_network | |||
| @property | |||
| def enc_key(self): | |||
| """Get the value of _enc_key""" | |||
| """ | |||
| Get the byte-type key used for encryption. | |||
| Returns: | |||
| (None, bytes), the byte-type key used for encryption. | |||
| """ | |||
| return self._enc_key | |||
| @property | |||
| def enc_mode(self): | |||
| """Get the value of _enc_mode""" | |||
| """ | |||
| Get the value of the encryption mode. | |||
| Returns: | |||
| str, encryption mode. | |||
| """ | |||
| return self._enc_mode | |||
| @property | |||
| def append_dict(self): | |||
| """Get the value of append_dict.""" | |||
| """ | |||
| Get the dictionary of extra information saved to the checkpoint file. | |||
| Returns: | |||
| Dict, the extra information saved to the checkpoint file. | |||
| """ | |||
| return self._append_dict | |||
| def get_checkpoint_policy(self): | |||
| """Get the policy of checkpoint.""" | |||
| """ | |||
| Get the policy of checkpoint. | |||
| Returns: | |||
| Dict, the information of checkpoint policy. | |||
| """ | |||
| checkpoint_policy = {'save_checkpoint_steps': self.save_checkpoint_steps, | |||
| 'save_checkpoint_seconds': self.save_checkpoint_seconds, | |||
| 'keep_checkpoint_max': self.keep_checkpoint_max, | |||
| @@ -343,10 +343,10 @@ def load(file_name, **kwargs): | |||
| - dec_mode (str): Specifies the decryption mode, to take effect when dec_key is set. | |||
| Option: 'AES-GCM' | 'AES-CBC'. Default: 'AES-GCM'. | |||
| Returns: | |||
| Object, a compiled graph that can be executed by `GraphCell`. | |||
| GraphCell, a compiled graph that can be executed by `GraphCell`. | |||
| Raises: | |||
| ValueError: MindIR file name is incorrect. | |||
| ValueError: MindIR file does not exist or `file_name` is not a string. | |||
| RuntimeError: Failed to parse MindIR file. | |||
| Examples: | |||
| @@ -417,7 +417,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N | |||
| Dict, key is parameter name, value is a Parameter. | |||
| Raises: | |||
| ValueError: Checkpoint file is incorrect. | |||
| ValueError: Checkpoint file's format is incorrect. | |||
| Examples: | |||
| >>> from mindspore import load_checkpoint | |||
| @@ -535,7 +535,7 @@ def _check_checkpoint_param(ckpt_file_name, filter_prefix=None): | |||
| def load_param_into_net(net, parameter_dict, strict_load=False): | |||
| """ | |||
| Load parameters into network. | |||
| Load parameters into the network and return the list of parameters that were not loaded into the network. | |||
| Args: | |||
| net (Cell): The network where the parameters will be loaded. | |||
| @@ -546,7 +546,7 @@ def load_param_into_net(net, parameter_dict, strict_load=False): | |||
| on the parameters of the same type, such as float32 to float16. Default: False. | |||
| Returns: | |||
| List, parameter name not loaded into the network | |||
| List, the names of the parameters that were not loaded into the network. | |||
| Raises: | |||
| TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary. | |||
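The return value described above (the names of parameters that were not loaded) can be illustrated with plain dictionaries. `load_params_sketch` is a hypothetical helper for this note only; it ignores `strict_load`, type checks and everything else the real `load_param_into_net` does.

```python
def load_params_sketch(net_params, parameter_dict):
    """Copy matching entries into net_params and return the names left unloaded."""
    not_loaded = []
    for name in net_params:
        if name in parameter_dict:
            # A value for this parameter exists in the checkpoint dictionary.
            net_params[name] = parameter_dict[name]
        else:
            # Nothing to load for this parameter; report it to the caller.
            not_loaded.append(name)
    return not_loaded


# Hypothetical parameter names; the checkpoint lacks "conv1.bias".
net_params = {"conv1.weight": None, "conv1.bias": None, "fc.weight": None}
ckpt = {"conv1.weight": [1.0, 2.0], "fc.weight": [3.0]}
missing = load_params_sketch(net_params, ckpt)
```

A caller would typically inspect `missing` right after loading: an empty list means every network parameter found a value, while a non-empty list points at names that differ between the network and the checkpoint.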
| @@ -878,9 +878,7 @@ def _generate_front_info_for_param_data_file(is_encrypt, kwargs): | |||
| def _change_file(f, dirname, external_local, is_encrypt, kwargs): | |||
| ''' | |||
| Change to another file to write parameter data | |||
| ''' | |||
| """Change to another file to write parameter data.""" | |||
| # The parameter has not been written to the file | |||
| front_info = _generate_front_info_for_param_data_file(is_encrypt, kwargs) | |||
| f.seek(0, 0) | |||
| @@ -1071,10 +1069,7 @@ def _save_dataset_to_mindir(model, dataset): | |||
| def quant_mode_manage(func): | |||
| """ | |||
| Inherit the quant_mode in old version. | |||
| """ | |||
| """Inherit the quant_mode in old version.""" | |||
| def warpper(network, *inputs, file_format, **kwargs): | |||
| if 'quant_mode' not in kwargs: | |||
| return network | |||
| @@ -1091,9 +1086,7 @@ def quant_mode_manage(func): | |||
| @quant_mode_manage | |||
| def _quant_export(network, *inputs, file_format, **kwargs): | |||
| """ | |||
| Exports MindSpore quantization predict model to deploy with AIR and MINDIR. | |||
| """ | |||
| """Exports MindSpore quantization predict model to deploy with AIR and MINDIR.""" | |||
| supported_device = ["Ascend", "GPU"] | |||
| supported_formats = ['AIR', 'MINDIR'] | |||
| quant_mode_formats = ['QUANT', 'NONQUANT'] | |||
| @@ -1131,23 +1124,20 @@ def _quant_export(network, *inputs, file_format, **kwargs): | |||
| def parse_print(print_file_name): | |||
| """ | |||
| Parse saved data generated by mindspore.ops.Print. Print is used to print data to screen in graph mode. | |||
| It can also been turned off by setting the parameter `print_file_path` in `context`, and the data will be saved | |||
| in a file specified by print_file_path. parse_print is used to parse the saved file. For more information | |||
| please refer to :func:`mindspore.context.set_context` and :class:`mindspore.ops.Print`. | |||
| Parse data file generated by mindspore.ops.Print. | |||
| Args: | |||
| print_file_name (str): The file name of saved print data. | |||
| print_file_name (str): The name of the file to be parsed. | |||
| Returns: | |||
| List, element of list is Tensor. | |||
| Raises: | |||
| ValueError: The print file may be empty, please make sure enter the correct file name. | |||
| ValueError: The print file does not exist or is empty. | |||
| RuntimeError: Failed to parse the file. | |||
| Examples: | |||
| >>> import numpy as np | |||
| >>> import mindspore | |||
| >>> import mindspore.ops as ops | |||
| >>> from mindspore import nn | |||
| >>> from mindspore import Tensor, context | |||
| @@ -1633,12 +1623,11 @@ def async_ckpt_thread_status(): | |||
| """ | |||
| Get the status of asynchronous save checkpoint thread. | |||
| When performing asynchronous save checkpoint, you can get the thread state through this function | |||
| to ensure that write checkpoint file is completed. | |||
| When saving a checkpoint asynchronously, you can use this function to determine whether the asynchronous thread has finished. | |||
| Returns: | |||
| True, Asynchronous save checkpoint thread is running. | |||
| False, Asynchronous save checkpoint thread is not executing. | |||
| bool, True if the asynchronous save checkpoint thread is running, | |||
| False if it is not. | |||
| """ | |||
| thr_list = threading.enumerate() | |||
| return True in [ele.getName() == "asyn_save_ckpt" for ele in thr_list] | |||
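The name-based check in the two lines above can be tried out with the standard library alone. The sketch below reuses the thread name `asyn_save_ckpt` from the diff; `save_thread_is_running` and `fake_async_save` are hypothetical names, and the worker only waits on an event instead of writing a checkpoint.

```python
import threading

# Thread name taken from the diff above.
SAVE_THREAD_NAME = "asyn_save_ckpt"


def save_thread_is_running():
    """Return True if a live thread carries the save-thread name."""
    return any(t.name == SAVE_THREAD_NAME for t in threading.enumerate())


def fake_async_save(done):
    """Stand-in for a slow checkpoint write: block until `done` is set."""
    done.wait()


done = threading.Event()
worker = threading.Thread(target=fake_async_save, args=(done,), name=SAVE_THREAD_NAME)

before_start = save_thread_is_running()   # the thread has not been started yet
worker.start()
while_running = save_thread_is_running()  # the thread is alive, waiting on the event
done.set()
worker.join()
after_join = save_thread_is_running()     # the thread has finished
```

`threading.enumerate()` lists only threads that are currently alive, so the check turns true once the worker starts and false again after it ends. In this sketch the caller already holds `worker` and can simply `join()` it; the name lookup is useful when the thread object itself is not accessible.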