You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

debug_ops.py 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """debug_ops"""
  16. from ..._checkparam import Validator as validator
  17. from ...common import dtype as mstype
  18. from ..primitive import Primitive, prim_attr_register, PrimitiveWithInfer
  19. class ScalarSummary(Primitive):
  20. """
  21. Output scalar to protocol buffer through scalar summary operator.
  22. Inputs:
  23. - **name** (str) - The name of the input variable.
  24. - **value** (Tensor) - The value of scalar.
  25. Examples:
  26. >>> class SummaryDemo(nn.Cell):
  27. >>> def __init__(self,):
  28. >>> super(SummaryDemo, self).__init__()
  29. >>> self.summary = P.ScalarSummary()
  30. >>> self.add = P.TensorAdd()
  31. >>>
  32. >>> def construct(self, x, y):
  33. >>> name = "x"
  34. >>> self.summary(name, x)
  35. >>> x = self.add(x, y)
  36. >>> return x
  37. """
  38. @prim_attr_register
  39. def __init__(self):
  40. """init"""
  41. def __call__(self, *args, **kwargs):
  42. pass
  43. class ImageSummary(Primitive):
  44. """
  45. Output image tensor to protocol buffer through image summary operator.
  46. Inputs:
  47. - **name** (str) - The name of the input variable.
  48. - **value** (Tensor) - The value of image.
  49. Examples:
  50. >>> class Net(nn.Cell):
  51. >>> def __init__(self):
  52. >>> super(Net, self).__init__()
  53. >>> self.summary = P.ImageSummary()
  54. >>>
  55. >>> def construct(self, x):
  56. >>> name = "image"
  57. >>> out = self.summary(name, x)
  58. >>> return out
  59. """
  60. @prim_attr_register
  61. def __init__(self):
  62. """init"""
  63. def __call__(self, *args, **kwargs):
  64. pass
  65. class TensorSummary(Primitive):
  66. """
  67. Output tensor to protocol buffer through tensor summary operator.
  68. Inputs:
  69. - **name** (str) - The name of the input variable.
  70. - **value** (Tensor) - The value of tensor.
  71. Examples:
  72. >>> class SummaryDemo(nn.Cell):
  73. >>> def __init__(self,):
  74. >>> super(SummaryDemo, self).__init__()
  75. >>> self.summary = P.TensorSummary()
  76. >>> self.add = P.TensorAdd()
  77. >>>
  78. >>> def construct(self, x, y):
  79. >>> x = self.add(x, y)
  80. >>> name = "x"
  81. >>> self.summary(name, x)
  82. >>> return x
  83. """
  84. @prim_attr_register
  85. def __init__(self):
  86. """init"""
  87. def __call__(self, *args, **kwargs):
  88. pass
  89. class HistogramSummary(Primitive):
  90. """
  91. Output tensor to protocol buffer through histogram summary operator.
  92. Inputs:
  93. - **name** (str) - The name of the input variable.
  94. - **value** (Tensor) - The value of tensor, and the rank of tensor should be greater than 0.
  95. Examples:
  96. >>> class SummaryDemo(nn.Cell):
  97. >>> def __init__(self,):
  98. >>> super(SummaryDemo, self).__init__()
  99. >>> self.summary = P.HistogramSummary()
  100. >>> self.add = P.TensorAdd()
  101. >>>
  102. >>> def construct(self, x, y):
  103. >>> x = self.add(x, y)
  104. >>> name = "x"
  105. >>> self.summary(name, x)
  106. >>> return x
  107. """
  108. @prim_attr_register
  109. def __init__(self):
  110. """init"""
  111. class InsertGradientOf(PrimitiveWithInfer):
  112. """
  113. Attach callback to graph node that will be invoked on the node's gradient.
  114. Args:
  115. f (Function): MindSpore's Function. Callback function.
  116. Inputs:
  117. - **input_x** (Tensor) - The graph node to attach to.
  118. Outputs:
  119. Tensor, returns `input_x` directly. `InsertGradientOf` does not affect the forward result.
  120. Examples:
  121. >>> def clip_gradient(dx):
  122. >>> ret = dx
  123. >>> if ret > 1.0:
  124. >>> ret = 1.0
  125. >>>
  126. >>> if ret < 0.2:
  127. >>> ret = 0.2
  128. >>>
  129. >>> return ret
  130. >>>
  131. >>> clip = P.InsertGradientOf(clip_gradient)
  132. >>> grad_all = C.GradOperation('get_all', get_all=True)
  133. >>> def InsertGradientOfClipDemo():
  134. >>> def clip_test(x, y):
  135. >>> x = clip(x)
  136. >>> y = clip(y)
  137. >>> c = x * y
  138. >>> return c
  139. >>>
  140. >>> @ms_function
  141. >>> def f(x, y):
  142. >>> return clip_test(x, y)
  143. >>>
  144. >>> def fd(x, y):
  145. >>> return grad_all(clip_test)(x, y)
  146. >>>
  147. >>> print("forward: ", f(1.1, 0.1))
  148. >>> print("clip_gradient:", fd(1.1, 0.1))
  149. """
  150. @prim_attr_register
  151. def __init__(self, f):
  152. self.f = f
  153. def __call__(self, x):
  154. """run in PyNative mode."""
  155. return x
  156. def infer_shape(self, x_shape):
  157. return x_shape
  158. def infer_dtype(self, x_type):
  159. return x_type
  160. class Print(PrimitiveWithInfer):
  161. """
  162. Output tensor or string to stdout.
  163. Note:
  164. The print operation cannot support float64 and bool types currently.
  165. Inputs:
  166. - **input_x** (Union[Tensor, str]) - The graph node to attach to. The input supports
  167. multiple strings and tensors which are separated by ','.
  168. Examples:
  169. >>> class PrintDemo(nn.Cell):
  170. >>> def __init__(self):
  171. >>> super(PrintDemo, self).__init__()
  172. >>> self.print = P.Print()
  173. >>>
  174. >>> def construct(self, x, y):
  175. >>> self.print('Print Tensor x and Tensor y:', x, y)
  176. >>> return x
  177. """
  178. @prim_attr_register
  179. def __init__(self):
  180. pass
  181. def __call__(self, *args):
  182. for arg in args:
  183. print(arg)
  184. def infer_shape(self, *inputs):
  185. return [1]
  186. def infer_dtype(self, *inputs):
  187. for dtype in inputs:
  188. validator.check_subclass("input", dtype, (mstype.tensor, mstype.string), self.name)
  189. return mstype.int32