mindspore.dataset.WaitedDSCallback ================================== .. py:class:: mindspore.dataset.WaitedDSCallback(step_size=1) 数据集自定义回调类的抽象基类,用于与训练回调类(`mindspore.callback `_)的同步。 可用于在每个step或epoch开始前执行自定义的回调方法,注意,第二个step或epoch开始时才会触发该调用。 例如在自动数据增强中根据上一个epoch的loss值来更新增强算子参数配置。 用户可通过 `train_run_context` 获取模型相关信息。如 `network` 、 `train_network` 、 `epoch_num` 、 `batch_num` 、 `loss_fn` 、 `optimizer` 、 `parallel_mode` 、 `device_number` 、 `list_callback` 、 `cur_epoch_num` 、 `cur_step_num` 、 `dataset_sink_mode` 、 `net_outputs` 等,详见 `mindspore.callback `_ 。 用户可通过 `ds_run_context` 获取数据处理管道相关信息。包括 `cur_epoch_num` (当前epoch数)、 `cur_step_num_in_epoch` (当前epoch的step数)、 `cur_step_num` (当前step数)。 **参数:** - **step_size** (int, optional) - 每个step包含的数据行数。通常step_size与batch_size一致,默认值:1。 **样例:** >>> import mindspore.nn as nn >>> from mindspore.dataset import WaitedDSCallback >>> from mindspore import context >>> from mindspore.train import Model >>> from mindspore.train.callback import Callback >>> >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU") >>> >>> # 自定义用于数据处理管道同步数据的回调类 >>> class MyWaitedCallback(WaitedDSCallback): ... def __init__(self, events, step_size=1): ... super().__init__(step_size) ... self.events = events ... ... # epoch开始前数据处理管道要执行的回调函数 ... def sync_epoch_begin(self, train_run_context, ds_run_context): ... event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}" ... self.events.append(event) ... ... # step开始前数据处理管道要执行的回调函数 ... def sync_step_begin(self, train_run_context, ds_run_context): ... event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}" ... self.events.append(event) >>> >>> # 自定义用于网络训练时同步数据的回调类 >>> class MyMSCallback(Callback): ... def __init__(self, events): ... self.events = events ... ... # epoch结束网络训练要执行的回调函数 ... def epoch_end(self, run_context): ... cb_params = run_context.original_args() ... event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}" ... self.events.append(event) ... ... # step结束网络训练要执行的回调函数 ... def step_end(self, run_context): ... cb_params = run_context.original_args() ... event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}" ... self.events.append(event) >>> >>> # 自定义网络 >>> class Net(nn.Cell): ... def construct(self, x, y): ... return x >>> >>> # 声明一个网络训练与数据处理同步的数据 >>> events = [] >>> >>> # 声明数据处理管道和网络训练的回调类 >>> my_cb1 = MyWaitedCallback(events, 1) >>> my_cb2 = MyMSCallback(events) >>> arr = [1, 2, 3, 4] >>> # 构建数据处理管道 >>> data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False) >>> # 将数据处理管道的回调类加入到map中 >>> data = data.map(operations=(lambda x: x), callbacks=my_cb1) >>> >>> net = Net() >>> model = Model(net) >>> >>> # 将数据处理管道和网络训练的回调类加入到模型训练的回调列表中 >>> model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1]) .. py:method:: begin(run_context) 用于定义在网络训练开始前执行的回调方法。 **参数:** - **run_context** (RunContext) - 网络训练运行信息。 .. py:method:: ds_begin(ds_run_context) 用于定义在数据处理管道启动前执行的回调方法。 **参数:** - **ds_run_context** (RunContext) - 数据处理管道运行信息。 .. py:method:: ds_epoch_begin(ds_run_context) 内部方法,不能被调用或者重写。通过重写mindspore.dataset.DSCallback.ds_epoch_begin 实现与mindspore.train.callback.Callback.epoch_end回调同步。 **参数:** **ds_run_context**:数据处理管道运行信息。 .. py:method:: ds_epoch_end(ds_run_context) 用于定义在每个数据epoch结束后执行的回调方法。 **参数:** - **ds_run_context** (RunContext) - 数据处理管道运行信息。 .. py:method:: ds_step_begin(ds_run_context) 内部方法,不能被调用或者重写。通过重写mindspore.dataset.DSCallback.ds_step_begin 实现与mindspore.train.callback.Callback.step_end回调同步。 **参数:** **ds_run_context**:数据处理管道运行信息。 .. py:method:: ds_step_end(ds_run_context) 用于定义在每个数据step结束后执行的回调方法。 **参数:** - **ds_run_context** (RunContext) - 数据处理管道运行信息。 .. py:method:: end(run_context) 内部方法,当网络训练结束时释放等待。 **参数:** **run_context**:网络训练运行信息。 .. py:method:: epoch_begin(run_context) 用于定义在每个训练epoch开始前执行的回调方法。 **参数:** - **run_context** (RunContext) - 网络训练运行信息。 .. py:method:: epoch_end(run_context) 内部方法,不能被调用或重写。通过重写mindspore.train.callback.Callback.epoch_end来释放ds_epoch_begin的等待。 **参数:** **run_context**:网络训练运行信息。 .. py:method:: step_begin(run_context) 用于定义在每个训练step开始前执行的回调方法。 **参数:** - **run_context** (RunContext) - 网络训练运行信息。 .. py:method:: step_end(run_context) 内部方法,不能被调用或重写。通过重写mindspore.train.callback.Callback.step_end来释放 `ds_step_begin` 的等待。 **参数:** **run_context**:网络训练运行信息。 .. py:method:: sync_epoch_begin(train_run_context, ds_run_context) 用于定义在每个数据epoch开始前,训练epoch结束后执行的回调方法。 **参数:** - **train_run_context**:包含前一个epoch的反馈信息的网络训练运行信息。 - **ds_run_context**:数据处理管道运行信息。 .. py:method:: sync_step_begin(train_run_context, ds_run_context) 用于定义在每个数据step开始前,训练step结束后执行的回调方法。 **参数:** - **train_run_context**:包含前一个step的反馈信息的网络训练运行信息。 - **ds_run_context**:数据处理管道运行信息。