You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

mindspore.dataset.WaitedDSCallback.rst 4.2 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. mindspore.dataset.WaitedDSCallback
  2. ==================================
  3. .. py:class:: mindspore.dataset.WaitedDSCallback(step_size=1)
  4. 用于自定义与训练回调同步的数据集回调类的抽象基类。
  5. 此类可用于自定义在step或epoch结束后执行的回调方法。
  6. 例如在自动数据增强中根据上一个epoch的loss值来更新增强算子参数配置。
  7. **参数:**
  8. - **step_size** (int, optional) - 每个step包含的数据行数。step大小通常与batch大小相等(默认值为1)。
  9. **样例:**
  10. >>> from mindspore.dataset import WaitedDSCallback
  11. >>>
  12. >>> my_cb = WaitedDSCallback(32)
  13. >>> # dataset为任意数据集实例
  14. >>> data = data.map(operations=AugOp(), callbacks=my_cb)
  15. >>> data = data.batch(32)
  16. >>> # 定义网络
  17. >>> model.train(epochs, data, callbacks=[my_cb])
  18. .. py:method:: begin(run_context)
  19. 用于定义在网络训练开始前执行的回调方法。
  20. **参数:**
  21. - **run_context** (RunContext) - 网络训练运行信息。
  22. .. py:method:: ds_begin(ds_run_context)
  23. 用于定义在数据处理管道启动前执行的回调方法。
  24. **参数:**
  25. - **ds_run_context** (RunContext) - 数据处理管道运行信息。
  26. .. py:method:: ds_epoch_begin(ds_run_context)
  27. 内部方法,不能被调用或者重写。通过重写mindspore.dataset.DSCallback.ds_epoch_begin 实现与mindspore.train.callback.Callback.epoch_end回调同步。
  28. **参数:**
  29. **ds_run_context**:数据处理管道运行信息。
  30. .. py:method:: ds_epoch_end(ds_run_context)
  31. 用于定义在每个数据epoch结束后执行的回调方法。
  32. **参数:**
  33. - **ds_run_context** (RunContext) - 数据处理管道运行信息。
  34. .. py:method:: ds_step_begin(ds_run_context)
  35. 内部方法,不能被调用或者重写。通过重写mindspore.dataset.DSCallback.ds_step_begin
  36. 实现与mindspore.train.callback.Callback.step_end回调同步。
  37. **参数:**
  38. **ds_run_context**:数据处理管道运行信息。
  39. .. py:method:: ds_step_end(ds_run_context)
  40. 用于定义在每个数据step结束后执行的回调方法。
  41. **参数:**
  42. - **ds_run_context** (RunContext) - 数据处理管道运行信息。
  43. .. py:method:: end(run_context)
  44. 内部方法,当网络训练结束时释放等待。
  45. **参数:**
  46. **run_context**:网络训练运行信息。
  47. .. py:method:: epoch_begin(run_context)
  48. 用于定义在每个训练epoch开始前执行的回调方法。
  49. **参数:**
  50. - **run_context** (RunContext) - 网络训练运行信息。
  51. .. py:method:: epoch_end(run_context)
  52. 内部方法,不能被调用或重写。通过重写mindspore.train.callback.Callback.epoch_end来释放ds_epoch_begin的等待。
  53. **参数:**
  54. **run_context**:网络训练运行信息。
  55. .. py:method:: step_begin(run_context)
  56. 用于定义在每个训练step开始前执行的回调方法。
  57. **参数:**
  58. - **run_context** (RunContext) - 网络训练运行信息。
  59. .. py:method:: step_end(run_context)
  60. 内部方法,不能被调用或重写。通过重写mindspore.train.callback.Callback.step_end来释放 `ds_step_begin` 的等待。
  61. **参数:**
  62. **run_context**:网络训练运行信息。
  63. .. py:method:: sync_epoch_begin(train_run_context, ds_run_context)
  64. 用于定义在每个数据epoch开始前,训练epoch结束后执行的回调方法。
  65. **参数:**
  66. - **train_run_context**:包含前一个epoch的反馈信息的网络训练运行信息。
  67. - **ds_run_context**:数据处理管道运行信息。
  68. .. py:method:: sync_step_begin(train_run_context, ds_run_context)
  69. 用于定义在每个数据step开始前,训练step结束后执行的回调方法。
  70. **参数:**
  71. - **train_run_context**:包含前一个step的反馈信息的网络训练运行信息。
  72. - **ds_run_context**:数据处理管道运行信息。