You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ds_callback.py 9.2 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """
  16. Python callback class
  17. """
  18. import threading
  19. from mindspore._c_dataengine import PyDSCallback
  20. from mindspore.train.callback import Callback
  21. import mindspore.dataset as ds
  22. from .validators import check_callback
  23. class DSCallback:
  24. """
  25. Abstract base class used to build a dataset callback class.
  26. Args:
  27. step_size (int, optional): The number of steps between the step_begin and step_end are called (Default=1).
  28. Examples:
  29. >>> from mindspore.dataset import DSCallback
  30. >>>
  31. >>> class PrintInfo(DSCallback):
  32. ... def ds_epoch_end(self, ds_run_context):
  33. ... print(cb_params.cur_epoch_num)
  34. ... print(cb_params.cur_step_num)
  35. >>>
  36. >>> # dataset is an instance of Dataset object
  37. >>> dataset = dataset.map(operations=op, callbacks=PrintInfo())
  38. """
  39. @check_callback
  40. def __init__(self, step_size=1):
  41. self.step_size = step_size
  42. def ds_begin(self, ds_run_context):
  43. """
  44. Called before the data pipeline is started.
  45. Args:
  46. ds_run_context (RunContext): Include some information of the pipeline.
  47. """
  48. def ds_epoch_begin(self, ds_run_context):
  49. """
  50. Called before a new epoch is started.
  51. Args:
  52. ds_run_context (RunContext): Include some information of the pipeline.
  53. """
  54. def ds_epoch_end(self, ds_run_context):
  55. """
  56. Called after an epoch is finished.
  57. Args:
  58. ds_run_context (RunContext): Include some information of the pipeline.
  59. """
  60. def ds_step_begin(self, ds_run_context):
  61. """
  62. Called before each step start.
  63. Args:
  64. ds_run_context (RunContext): Include some information of the pipeline.
  65. """
  66. def ds_step_end(self, ds_run_context):
  67. """
  68. Called after each step finished.
  69. Args:
  70. ds_run_context (RunContext): Include some information of the pipeline.
  71. """
  72. def create_runtime_obj(self):
  73. """
  74. Creates a runtime (C++) object from the callback methods defined by the user.
  75. Returns:
  76. _c_dataengine.PyDSCallback.
  77. """
  78. c_cb = PyDSCallback(self.step_size)
  79. at_least_one = False
  80. if self.__class__.ds_begin != DSCallback.ds_begin:
  81. c_cb.set_begin(self.ds_begin)
  82. at_least_one = True
  83. if self.__class__.ds_epoch_begin != DSCallback.ds_epoch_begin:
  84. c_cb.set_epoch_begin(self.ds_epoch_begin)
  85. at_least_one = True
  86. if self.__class__.ds_epoch_end != DSCallback.ds_epoch_end:
  87. c_cb.set_epoch_end(self.ds_epoch_end)
  88. at_least_one = True
  89. if self.__class__.ds_step_begin != DSCallback.ds_step_begin:
  90. c_cb.set_step_begin(self.ds_step_begin)
  91. at_least_one = True
  92. if self.__class__.ds_step_end != DSCallback.ds_step_end:
  93. c_cb.set_step_end(self.ds_step_end)
  94. at_least_one = True
  95. if not at_least_one:
  96. raise AttributeError("Provided Callback class did not override any of the 6 callback methods.")
  97. return c_cb
  98. class WaitedDSCallback(Callback, DSCallback):
  99. """
  100. Abstract base class used to build a dataset callback class that is synchronized with the training callback.
  101. This class can be used to execute a user defined logic right after the previous step or epoch.
  102. For example, one augmentation needs the loss from the previous trained epoch to update some of its parameters.
  103. Args:
  104. step_size (int, optional): The number of rows in each step. Usually the step size
  105. will be equal to the batch size (Default=1).
  106. Examples:
  107. >>> from mindspore.dataset import WaitedDSCallback
  108. >>>
  109. >>> my_cb = WaitedDSCallback(32)
  110. >>> # dataset is an instance of Dataset object
  111. >>> dataset = dataset.map(operations=AugOp(), callbacks=my_cb)
  112. >>> dataset = dataset.batch(32)
  113. >>> # define the model
  114. >>> model.train(epochs, data, callbacks=[my_cb])
  115. """
  116. def __init__(self, step_size=1):
  117. super().__init__()
  118. self.step_size = step_size
  119. self.step_event = threading.Event()
  120. self.step_run_context = None
  121. self.epoch_event = threading.Event()
  122. self.epoch_run_context = None
  123. self.training_ended = False
  124. def sync_epoch_begin(self, train_run_context, ds_run_context):
  125. """
  126. Called before a new dataset epoch is started and after the previous training epoch is ended.
  127. Args:
  128. train_run_context: Include some information of the model with feedback from the previous epoch.
  129. ds_run_context: Include some information of the dataset pipeline.
  130. """
  131. def sync_step_begin(self, train_run_context, ds_run_context):
  132. """
  133. Called before a new dataset step is started and after the previous training step is ended.
  134. Args:
  135. train_run_context: Include some information of the model with feedback from the previous step.
  136. ds_run_context: Include some information of the dataset pipeline.
  137. """
  138. def epoch_end(self, run_context):
  139. """
  140. Internal method, do not call/override. Defines epoch_end of Callback to release the wait in ds_epoch_begin.
  141. Args:
  142. run_context: Include some information of the model.
  143. """
  144. self.epoch_run_context = run_context
  145. self.epoch_event.set()
  146. def ds_epoch_begin(self, ds_run_context):
  147. """
  148. Internal method, do not call/override. Defines ds_epoch_begin of DSCallback to wait for MS epoch_end callback.
  149. Args:
  150. ds_run_context: Include some information of the pipeline.
  151. """
  152. if ds_run_context.cur_epoch_num > 1:
  153. if not self.training_ended:
  154. success = self.epoch_event.wait(timeout=ds.config.get_callback_timeout())
  155. self.epoch_event.clear()
  156. if not success:
  157. raise RuntimeError(f"ds_epoch_begin timed out after {ds.config.get_callback_timeout()} second(s).")
  158. # by the time this thread wakes up, self.epoch_run_context is already available
  159. self.sync_epoch_begin(self.epoch_run_context, ds_run_context)
  160. def step_end(self, run_context):
  161. """
  162. Internal method, do not call/override. Defines step_end of Callback to release the wait in ds_step_begin.
  163. Args:
  164. run_context: Include some information of the model.
  165. """
  166. self.step_run_context = run_context
  167. self.step_event.set()
  168. def ds_step_begin(self, ds_run_context):
  169. """
  170. Internal method, do not call/override. Defines ds_step_begin of DSCallback to wait for MS step_end callback.
  171. Args:
  172. ds_run_context: Include some information of the pipeline.
  173. """
  174. if ds_run_context.cur_step_num > self.step_size:
  175. if not self.training_ended:
  176. success = self.step_event.wait(timeout=ds.config.get_callback_timeout())
  177. self.step_event.clear()
  178. if not success:
  179. raise RuntimeError(f"ds_step_begin timed out after {ds.config.get_callback_timeout()} second(s).")
  180. # by the time this thread wakes up, self.epoch_run_context is already available
  181. self.sync_step_begin(self.step_run_context, ds_run_context)
  182. def create_runtime_obj(self):
  183. """
  184. Creates a runtime (C++) object from the callback methods defined by the user. This method is internal.
  185. Returns:
  186. _c_dataengine.PyDSCallback.
  187. """
  188. c_cb = PyDSCallback(self.step_size)
  189. at_least_one = False
  190. if self.__class__.sync_step_begin != WaitedDSCallback.sync_step_begin:
  191. c_cb.set_step_begin(self.ds_step_begin)
  192. at_least_one = True
  193. if self.__class__.sync_epoch_begin != WaitedDSCallback.sync_epoch_begin:
  194. c_cb.set_epoch_begin(self.ds_epoch_begin)
  195. at_least_one = True
  196. if not at_least_one:
  197. raise AttributeError("Provided Callback class did not override any of the 2 callback methods.")
  198. return c_cb
  199. def end(self, run_context):
  200. """
  201. Internal method, release the wait if training is ended.
  202. Args:
  203. run_context: Include some information of the model.
  204. """
  205. self.epoch_end(run_context)
  206. self.step_end(run_context)
  207. self.training_ended = True