You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

ds_callback.py 8.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Python callback class
  17. """
  18. import threading
  19. from mindspore._c_dataengine import PyDSCallback
  20. from mindspore.train.callback import Callback
  21. import mindspore.dataset as ds
  22. from .validators import check_callback
  23. class DSCallback:
  24. """
  25. Abstract base class used to build a dataset callback class.
  26. Args:
  27. step_size (int, optional): The number of steps before the step_begin and step_end are called (Default=1).
  28. Examples:
  29. >>> class PrintInfo(DSCallback):
  30. >>> def ds_epoch_end(self, ds_run_context):
  31. >>> print(cb_params.cur_epoch_num)
  32. >>> print(cb_params.cur_step_num)
  33. >>>
  34. >>> data = data.map(operations=op, callbacks=PrintInfo())
  35. """
  36. @check_callback
  37. def __init__(self, step_size=1):
  38. self.step_size = step_size
  39. def ds_begin(self, ds_run_context):
  40. """
  41. Called before the data pipeline is started.
  42. Args:
  43. ds_run_context (RunContext): Include some information of the pipeline.
  44. """
  45. def ds_epoch_begin(self, ds_run_context):
  46. """
  47. Called before a new epoch is started.
  48. Args:
  49. ds_run_context (RunContext): Include some information of the pipeline.
  50. """
  51. def ds_epoch_end(self, ds_run_context):
  52. """
  53. Called after an epoch is finished.
  54. Args:
  55. ds_run_context (RunContext): Include some information of the pipeline.
  56. """
  57. def ds_step_begin(self, ds_run_context):
  58. """
  59. Called before n steps are started.
  60. Args:
  61. ds_run_context (RunContext): Include some information of the pipeline.
  62. """
  63. def ds_step_end(self, ds_run_context):
  64. """
  65. Called after n steps are finished.
  66. Args:
  67. ds_run_context (RunContext): Include some information of the pipeline.
  68. """
  69. def create_runtime_obj(self):
  70. """
  71. Creates a runtime (C++) object from the callback methods defined by the user.
  72. Returns: _c_dataengine.PyDSCallback
  73. """
  74. c_cb = PyDSCallback(self.step_size)
  75. at_least_one = False
  76. if self.__class__.ds_begin != DSCallback.ds_begin:
  77. c_cb.set_begin(self.ds_begin)
  78. at_least_one = True
  79. if self.__class__.ds_epoch_begin != DSCallback.ds_epoch_begin:
  80. c_cb.set_epoch_begin(self.ds_epoch_begin)
  81. at_least_one = True
  82. if self.__class__.ds_epoch_end != DSCallback.ds_epoch_end:
  83. c_cb.set_epoch_end(self.ds_epoch_end)
  84. at_least_one = True
  85. if self.__class__.ds_step_begin != DSCallback.ds_step_begin:
  86. c_cb.set_step_begin(self.ds_step_begin)
  87. at_least_one = True
  88. if self.__class__.ds_step_end != DSCallback.ds_step_end:
  89. c_cb.set_step_end(self.ds_step_end)
  90. at_least_one = True
  91. if not at_least_one:
  92. raise AttributeError("Provided Callback class did not override any of the 6 callback methods.")
  93. return c_cb
  94. class WaitedDSCallback(Callback, DSCallback):
  95. """
  96. Abstract base class used to build a dataset callback class that are synchronized with the training callback.
  97. This class can be used to execute a user defined logic right after the previous step or epoch.
  98. For example, one augmentation needs the loss from the previous trained epoch to update some of its parameters.
  99. Examples:
  100. >>> my_cb = MyWaitedCallback(32)
  101. >>> data = data.map(operations=AugOp(), callbacks=my_cb)
  102. >>> data = data.batch(32)
  103. >>> # define the model
  104. >>> model.train(epochs, data, callbacks=[my_cb])
  105. Args:
  106. step_size: the number of rows in each step.
  107. Usually the step size will be equal to the batch size (Default=1)
  108. """
  109. def __init__(self, step_size=1):
  110. super().__init__()
  111. self.step_size = step_size
  112. self.step_event = threading.Event()
  113. self.step_run_context = None
  114. self.epoch_event = threading.Event()
  115. self.epoch_run_context = None
  116. self.training_ended = False
  117. def sync_epoch_begin(self, train_run_context, ds_run_context):
  118. """
  119. Called before a new dataset epoch is started and after the previous training epoch is ended.
  120. Args:
  121. train_run_context: Include some information of the model with feedback from the previous epoch.
  122. ds_run_context: Include some information of the dataset pipeline.
  123. """
  124. def sync_step_begin(self, train_run_context, ds_run_context):
  125. """
  126. Called before a new dataset step is started and after the previous training step is ended.
  127. Args:
  128. train_run_context: Include some information of the model with feedback from the previous step.
  129. ds_run_context: Include some information of the dataset pipeline.
  130. """
  131. def epoch_end(self, run_context):
  132. """
  133. Internal method, do not call/override. Defines epoch_end of Callback to release the wait in ds_epoch_begin.
  134. Args:
  135. run_context: Include some information of the model.
  136. """
  137. self.epoch_run_context = run_context
  138. self.epoch_event.set()
  139. def ds_epoch_begin(self, ds_run_context):
  140. """
  141. Internal method, do not call/override. Defines ds_epoch_begin of DSCallback to wait for MS epoch_end callback.
  142. Args:
  143. ds_run_context: Include some information of the pipeline.
  144. """
  145. if ds_run_context.cur_epoch_num > 1:
  146. if not self.training_ended:
  147. success = self.epoch_event.wait(timeout=ds.config.get_callback_timeout())
  148. self.epoch_event.clear()
  149. if not success:
  150. raise RuntimeError(f"ds_epoch_begin timed out after {ds.config.get_callback_timeout()} second(s)")
  151. # by the time this thread wakes up, self.epoch_run_context is already available
  152. self.sync_epoch_begin(self.epoch_run_context, ds_run_context)
  153. def step_end(self, run_context):
  154. """
  155. Internal method, do not call/override. Defines step_end of Callback to release the wait in ds_step_begin.
  156. Args:
  157. run_context: Include some information of the model.
  158. """
  159. self.step_run_context = run_context
  160. self.step_event.set()
  161. def ds_step_begin(self, ds_run_context):
  162. """
  163. Internal method, do not call/override. Defines ds_step_begin of DSCallback to wait for MS step_end callback.
  164. Args:
  165. ds_run_context: Include some information of the pipeline.
  166. """
  167. if ds_run_context.cur_step_num > self.step_size:
  168. if not self.training_ended:
  169. success = self.step_event.wait(timeout=ds.config.get_callback_timeout())
  170. self.step_event.clear()
  171. if not success:
  172. raise RuntimeError(f"ds_step_begin timed out after {ds.config.get_callback_timeout()} second(s)")
  173. # by the time this thread wakes up, self.epoch_run_context is already available
  174. self.sync_step_begin(self.step_run_context, ds_run_context)
  175. def create_runtime_obj(self):
  176. """
  177. Creates a runtime (C++) object from the callback methods defined by the user. This method is internal.
  178. Returns: _c_dataengine.PyDSCallback
  179. """
  180. c_cb = PyDSCallback(self.step_size)
  181. at_least_one = False
  182. if self.__class__.sync_step_begin != WaitedDSCallback.sync_step_begin:
  183. c_cb.set_step_begin(self.ds_step_begin)
  184. at_least_one = True
  185. if self.__class__.sync_epoch_begin != WaitedDSCallback.sync_epoch_begin:
  186. c_cb.set_epoch_begin(self.ds_epoch_begin)
  187. at_least_one = True
  188. if not at_least_one:
  189. raise AttributeError("Provided Callback class did not override any of the 2 callback methods.")
  190. return c_cb
  191. def end(self, run_context):
  192. self.epoch_end(run_context)
  193. self.step_end(run_context)
  194. self.training_ended = True