You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.dataset.WaitedDSCallback.rst 4.2 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. mindspore.dataset.WaitedDSCallback
  2. ==================================
  3. .. py:class:: mindspore.dataset.WaitedDSCallback(step_size=1)
  4. 用于自定义与训练回调同步的数据集回调类的抽象基类。
  5. 此类可用于自定义在step或epoch结束后执行的回调方法。
  6. 例如在自动数据增强中根据上一个epoch的loss值来更新增强算子参数配置。
  7. **参数:**
  8. **step_size** (int, optional):每个step包含的数据行数。step大小通常与batch大小相等(默认值为1)。
  9. **样例:**
  10. >>> from mindspore.dataset import WaitedDSCallback
  11. >>>
  12. >>> my_cb = WaitedDSCallback(32)
  13. >>> # dataset为任意数据集实例
  14. >>> data = data.map(operations=AugOp(), callbacks=my_cb)
  15. >>> data = data.batch(32)
  16. >>> # 定义网络
  17. >>> model.train(epochs, data, callbacks=[my_cb])
  18. .. py:method:: begin(run_context)
  19. 用于定义在网络训练开始前执行的回调方法。
  20. **参数:**
  21. **run_context** (RunContext):网络训练运行信息。
  22. .. py:method:: ds_begin(ds_run_context)
  23. 用于定义在数据处理管道启动前执行的回调方法。
  24. **参数:**
  25. **ds_run_context** (RunContext):数据处理管道运行信息。
  26. .. py:method:: ds_epoch_begin(ds_run_context)
  27. 内部方法,不能被调用或者重写。通过重写mindspore.dataset.DSCallback.ds_epoch_begin
  28. 实现与mindspore.train.callback.Callback.epoch_end回调同步。
  29. **参数:**
  30. **ds_run_context**:数据处理管道运行信息。
  31. .. py:method:: ds_epoch_end(ds_run_context)
  32. 用于定义在每个数据epoch结束后执行的回调方法。
  33. **参数:**
  34. **ds_run_context** (RunContext):数据处理管道运行信息。
  35. .. py:method:: ds_step_begin(ds_run_context)
  36. 内部方法,不能被调用或者重写。通过重写mindspore.dataset.DSCallback.ds_step_begin
  37. 实现与mindspore.train.callback.Callback.step_end回调同步。
  38. **参数:**
  39. **ds_run_context**:数据处理管道运行信息。
  40. .. py:method:: ds_step_end(ds_run_context)
  41. 用于定义在每个数据step结束后执行的回调方法。
  42. **参数:**
  43. **ds_run_context** (RunContext):数据处理管道运行信息。
  44. .. py:method:: end(run_context)
  45. 内部方法,当网络训练结束时释放等待。
  46. **参数:**
  47. **run_context**:网络训练运行信息。
  48. .. py:method:: epoch_begin(run_context)
  49. 用于定义在每个训练epoch开始前执行的回调方法。
  50. **参数:**
  51. **run_context** (RunContext):网络训练运行信息。
  52. .. py:method:: epoch_end(run_context)
  53. 内部方法,不能被调用或重写。通过重写mindspore.train.callback.Callback.epoch_end来释放ds_epoch_begin的等待。
  54. **参数:**
  55. **run_context**:网络训练运行信息。
  56. .. py:method:: step_begin(run_context)
  57. 用于定义在每个训练step开始前执行的回调方法。
  58. **参数:**
  59. **run_context** (RunContext):网络训练运行信息。
  60. .. py:method:: step_end(run_context)
  61. 内部方法,不能被调用或重写。通过重写mindspore.train.callback.Callback.step_end来释放 `ds_step_begin` 的等待。
  62. **参数:**
  63. **run_context**:网络训练运行信息。
  64. .. py:method:: sync_epoch_begin(train_run_context, ds_run_context)
  65. 用于定义在每个数据epoch开始前,训练epoch结束后执行的回调方法。
  66. **参数:**
  67. - **train_run_context**:包含前一个epoch的反馈信息的网络训练运行信息。
  68. - **ds_run_context**:数据处理管道运行信息。
  69. .. py:method:: sync_step_begin(train_run_context, ds_run_context)
  70. 用于定义在每个数据step开始前,训练step结束后执行的回调方法。
  71. **参数:**
  72. - **train_run_context**:包含前一个step的反馈信息的网络训练运行信息。
  73. - **ds_run_context**:数据处理管道运行信息。