You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.dataset.WaitedDSCallback.rst 7.6 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. mindspore.dataset.WaitedDSCallback
  2. ==================================
  3. .. py:class:: mindspore.dataset.WaitedDSCallback(step_size=1)
  4. 数据集自定义回调类的抽象基类,用于与训练回调类(`mindspore.callback <https://mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.train.html#mindspore.train.callback.Callback>`_)的同步。
  5. 可用于在每个step或epoch开始前执行自定义的回调方法,注意,第二个step或epoch开始时才会触发该调用。
  6. 例如在自动数据增强中根据上一个epoch的loss值来更新增强算子参数配置。
  7. 用户可通过 `train_run_context` 获取模型相关信息。如 `network` 、 `train_network` 、 `epoch_num` 、 `batch_num` 、 `loss_fn` 、 `optimizer` 、 `parallel_mode` 、 `device_number` 、 `list_callback` 、 `cur_epoch_num` 、 `cur_step_num` 、 `dataset_sink_mode` 、 `net_outputs` 等,详见 `mindspore.callback <https://mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.train.html#mindspore.train.callback.Callback>`_ 。
  8. 用户可通过 `ds_run_context` 获取数据处理管道相关信息。包括 `cur_epoch_num` (当前epoch数)、 `cur_step_num_in_epoch` (当前epoch的step数)、 `cur_step_num` (当前step数)。
  9. **参数:**
  10. - **step_size** (int, optional) - 每个step包含的数据行数。通常step_size与batch_size一致,默认值:1。
  11. **样例:**
  12. >>> import mindspore.nn as nn
  13. >>> from mindspore.dataset import WaitedDSCallback
  14. >>> from mindspore import context
  15. >>> from mindspore.train import Model
  16. >>> from mindspore.train.callback import Callback
  17. >>>
  18. >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
  19. >>>
  20. >>> # 自定义用于数据处理管道同步数据的回调类
  21. >>> class MyWaitedCallback(WaitedDSCallback):
  22. ... def __init__(self, events, step_size=1):
  23. ... super().__init__(step_size)
  24. ... self.events = events
  25. ...
  26. ... # epoch开始前数据处理管道要执行的回调函数
  27. ... def sync_epoch_begin(self, train_run_context, ds_run_context):
  28. ... event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
  29. ... self.events.append(event)
  30. ...
  31. ... # step开始前数据处理管道要执行的回调函数
  32. ... def sync_step_begin(self, train_run_context, ds_run_context):
  33. ... event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
  34. ... self.events.append(event)
  35. >>>
  36. >>> # 自定义用于网络训练时同步数据的回调类
  37. >>> class MyMSCallback(Callback):
  38. ... def __init__(self, events):
  39. ... self.events = events
  40. ...
  41. ... # epoch结束网络训练要执行的回调函数
  42. ... def epoch_end(self, run_context):
  43. ... cb_params = run_context.original_args()
  44. ... event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
  45. ... self.events.append(event)
  46. ...
  47. ... # step结束网络训练要执行的回调函数
  48. ... def step_end(self, run_context):
  49. ... cb_params = run_context.original_args()
  50. ... event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
  51. ... self.events.append(event)
  52. >>>
  53. >>> # 自定义网络
  54. >>> class Net(nn.Cell):
  55. ... def construct(self, x, y):
  56. ... return x
  57. >>>
  58. >>> # 声明一个网络训练与数据处理同步的数据
  59. >>> events = []
  60. >>>
  61. >>> # 声明数据处理管道和网络训练的回调类
  62. >>> my_cb1 = MyWaitedCallback(events, 1)
  63. >>> my_cb2 = MyMSCallback(events)
  64. >>> arr = [1, 2, 3, 4]
  65. >>> # 构建数据处理管道
  66. >>> data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
  67. >>> # 将数据处理管道的回调类加入到map中
  68. >>> data = data.map(operations=(lambda x: x), callbacks=my_cb1)
  69. >>>
  70. >>> net = Net()
  71. >>> model = Model(net)
  72. >>>
  73. >>> # 将数据处理管道和网络训练的回调类加入到模型训练的回调列表中
  74. >>> model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
  75. .. py:method:: begin(run_context)
  76. 用于定义在网络训练开始前执行的回调方法。
  77. **参数:**
  78. - **run_context** (RunContext) - 网络训练运行信息。
  79. .. py:method:: ds_begin(ds_run_context)
  80. 用于定义在数据处理管道启动前执行的回调方法。
  81. **参数:**
  82. - **ds_run_context** (RunContext) - 数据处理管道运行信息。
  83. .. py:method:: ds_epoch_begin(ds_run_context)
  84. 内部方法,不能被调用或者重写。通过重写mindspore.dataset.DSCallback.ds_epoch_begin 实现与mindspore.train.callback.Callback.epoch_end回调同步。
  85. **参数:**
  86. **ds_run_context**:数据处理管道运行信息。
  87. .. py:method:: ds_epoch_end(ds_run_context)
  88. 用于定义在每个数据epoch结束后执行的回调方法。
  89. **参数:**
  90. - **ds_run_context** (RunContext) - 数据处理管道运行信息。
  91. .. py:method:: ds_step_begin(ds_run_context)
  92. 内部方法,不能被调用或者重写。通过重写mindspore.dataset.DSCallback.ds_step_begin
  93. 实现与mindspore.train.callback.Callback.step_end回调同步。
  94. **参数:**
  95. **ds_run_context**:数据处理管道运行信息。
  96. .. py:method:: ds_step_end(ds_run_context)
  97. 用于定义在每个数据step结束后执行的回调方法。
  98. **参数:**
  99. - **ds_run_context** (RunContext) - 数据处理管道运行信息。
  100. .. py:method:: end(run_context)
  101. 内部方法,当网络训练结束时释放等待。
  102. **参数:**
  103. **run_context**:网络训练运行信息。
  104. .. py:method:: epoch_begin(run_context)
  105. 用于定义在每个训练epoch开始前执行的回调方法。
  106. **参数:**
  107. - **run_context** (RunContext) - 网络训练运行信息。
  108. .. py:method:: epoch_end(run_context)
  109. 内部方法,不能被调用或重写。通过重写mindspore.train.callback.Callback.epoch_end来释放ds_epoch_begin的等待。
  110. **参数:**
  111. **run_context**:网络训练运行信息。
  112. .. py:method:: step_begin(run_context)
  113. 用于定义在每个训练step开始前执行的回调方法。
  114. **参数:**
  115. - **run_context** (RunContext) - 网络训练运行信息。
  116. .. py:method:: step_end(run_context)
  117. 内部方法,不能被调用或重写。通过重写mindspore.train.callback.Callback.step_end来释放 `ds_step_begin` 的等待。
  118. **参数:**
  119. **run_context**:网络训练运行信息。
  120. .. py:method:: sync_epoch_begin(train_run_context, ds_run_context)
  121. 用于定义在每个数据epoch开始前,训练epoch结束后执行的回调方法。
  122. **参数:**
  123. - **train_run_context**:包含前一个epoch的反馈信息的网络训练运行信息。
  124. - **ds_run_context**:数据处理管道运行信息。
  125. .. py:method:: sync_step_begin(train_run_context, ds_run_context)
  126. 用于定义在每个数据step开始前,训练step结束后执行的回调方法。
  127. **参数:**
  128. - **train_run_context**:包含前一个step的反馈信息的网络训练运行信息。
  129. - **ds_run_context**:数据处理管道运行信息。