You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

debug_ops.py 13 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """debug_ops"""
  16. from types import FunctionType, MethodType
  17. from ..._checkparam import Validator as validator
  18. from ..._checkparam import Rel
  19. from ...common import dtype as mstype
  20. from ..primitive import prim_attr_register, PrimitiveWithInfer
  21. def _check_summary_param(name, value, class_name):
  22. """Checks the name and value is valid for summary."""
  23. n_type = name['dtype']
  24. n_value = name['value']
  25. validator.check_value_type('name', n_type, [type(mstype.string)], class_name)
  26. if not n_value:
  27. raise ValueError(f"For 'name' the value should by valid string in {class_name}, but got an empty string.")
  28. v_type = value['dtype']
  29. validator.check_value_type('value', v_type, [type(mstype.tensor)], class_name)
  30. # Note: The return value of the summary operator is not used,
  31. # so there's nothing special about the return `dtype` or `shape`, any value is ok.
  32. # The `value` should be set to None, else summary operators may be optimized at compile graph phase,
  33. # it cause summary operators can not record data in constant folding scene.
  34. SUMMARY_RETURN_VALUE = {'dtype': mstype.int32, 'shape': [1], 'value': None}
  35. class ScalarSummary(PrimitiveWithInfer):
  36. """
  37. Outputs a scalar to a protocol buffer through a scalar summary operator.
  38. Inputs:
  39. - **name** (str) - The name of the input variable, it must not be an empty string.
  40. - **value** (Tensor) - The value of scalar, and the shape of value must be [] or [1].
  41. Examples:
  42. >>> class SummaryDemo(nn.Cell):
  43. >>> def __init__(self,):
  44. >>> super(SummaryDemo, self).__init__()
  45. >>> self.summary = P.ScalarSummary()
  46. >>> self.add = P.TensorAdd()
  47. >>>
  48. >>> def construct(self, x, y):
  49. >>> name = "x"
  50. >>> self.summary(name, x)
  51. >>> x = self.add(x, y)
  52. >>> return x
  53. """
  54. @prim_attr_register
  55. def __init__(self):
  56. """init"""
  57. def __infer__(self, name, value):
  58. _check_summary_param(name, value, self.__class__.__name__)
  59. v_shape = value['shape']
  60. # In the summary, the value whose shape is [1] is also considered as a scalar.
  61. if v_shape and v_shape != [1]:
  62. raise ValueError(f"For 'value' the type should be scalar, "
  63. f"shape should be [] or [1] in {self.__class__.__name__}, but got {v_shape}.")
  64. return SUMMARY_RETURN_VALUE
  65. class ImageSummary(PrimitiveWithInfer):
  66. """
  67. Outputs image tensor to protocol buffer through image summary operator.
  68. Inputs:
  69. - **name** (str) - The name of the input variable, it must not be an empty string.
  70. - **value** (Tensor) - The value of image, the rank of tensor must be 4.
  71. Examples:
  72. >>> class Net(nn.Cell):
  73. >>> def __init__(self):
  74. >>> super(Net, self).__init__()
  75. >>> self.summary = P.ImageSummary()
  76. >>>
  77. >>> def construct(self, x):
  78. >>> name = "image"
  79. >>> out = self.summary(name, x)
  80. >>> return out
  81. """
  82. @prim_attr_register
  83. def __init__(self):
  84. """init"""
  85. def __infer__(self, name, value):
  86. _check_summary_param(name, value, self.__class__.__name__)
  87. # The shape dim of image should be 4.
  88. v_shape = value['shape']
  89. image_dim = 4
  90. if len(v_shape) != image_dim:
  91. raise ValueError(f"For 'value' the dim should be {image_dim} in {self.__class__.__name__},"
  92. f" but got {len(v_shape)}.")
  93. return SUMMARY_RETURN_VALUE
  94. class TensorSummary(PrimitiveWithInfer):
  95. """
  96. Outputs a tensor to a protocol buffer through a tensor summary operator.
  97. Inputs:
  98. - **name** (str) - The name of the input variable.
  99. - **value** (Tensor) - The value of tensor, and the rank of tensor must be greater than 0.
  100. Examples:
  101. >>> class SummaryDemo(nn.Cell):
  102. >>> def __init__(self,):
  103. >>> super(SummaryDemo, self).__init__()
  104. >>> self.summary = P.TensorSummary()
  105. >>> self.add = P.TensorAdd()
  106. >>>
  107. >>> def construct(self, x, y):
  108. >>> x = self.add(x, y)
  109. >>> name = "x"
  110. >>> self.summary(name, x)
  111. >>> return x
  112. """
  113. @prim_attr_register
  114. def __init__(self):
  115. """init"""
  116. def __infer__(self, name, value):
  117. _check_summary_param(name, value, self.__class__.__name__)
  118. v_shape = value['shape']
  119. # In the summary, the value whose shape is [] is not considered as a tensor.
  120. if not v_shape:
  121. raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, "
  122. f"shape should not be [].")
  123. return SUMMARY_RETURN_VALUE
  124. class HistogramSummary(PrimitiveWithInfer):
  125. """
  126. Outputs tensor to protocol buffer through histogram summary operator.
  127. Inputs:
  128. - **name** (str) - The name of the input variable.
  129. - **value** (Tensor) - The value of tensor, and the rank of tensor must be greater than 0.
  130. Examples:
  131. >>> class SummaryDemo(nn.Cell):
  132. >>> def __init__(self,):
  133. >>> super(SummaryDemo, self).__init__()
  134. >>> self.summary = P.HistogramSummary()
  135. >>> self.add = P.TensorAdd()
  136. >>>
  137. >>> def construct(self, x, y):
  138. >>> x = self.add(x, y)
  139. >>> name = "x"
  140. >>> self.summary(name, x)
  141. >>> return x
  142. """
  143. @prim_attr_register
  144. def __init__(self):
  145. """init"""
  146. def __infer__(self, name, value):
  147. _check_summary_param(name, value, self.__class__.__name__)
  148. v_shape = value['shape']
  149. # In the summary, the histogram value should be a tensor whose shape is not [].
  150. if not v_shape:
  151. raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, "
  152. f"shape should not be [].")
  153. return SUMMARY_RETURN_VALUE
  154. class InsertGradientOf(PrimitiveWithInfer):
  155. """
  156. Attaches callback to graph node that will be invoked on the node's gradient.
  157. Args:
  158. f (Function): MindSpore's Function. Callback function.
  159. Inputs:
  160. - **input_x** (Any) - The graph node to attach to.
  161. Outputs:
  162. Tensor, returns `input_x` directly. `InsertGradientOf` does not affect the forward result.
  163. Examples:
  164. >>> def clip_gradient(dx):
  165. >>> ret = dx
  166. >>> if ret > 1.0:
  167. >>> ret = 1.0
  168. >>>
  169. >>> if ret < 0.2:
  170. >>> ret = 0.2
  171. >>>
  172. >>> return ret
  173. >>>
  174. >>> clip = P.InsertGradientOf(clip_gradient)
  175. >>> grad_all = C.GradOperation(get_all=True)
  176. >>> def InsertGradientOfClipDemo():
  177. >>> def clip_test(x, y):
  178. >>> x = clip(x)
  179. >>> y = clip(y)
  180. >>> c = x * y
  181. >>> return c
  182. >>>
  183. >>> @ms_function
  184. >>> def f(x, y):
  185. >>> return clip_test(x, y)
  186. >>>
  187. >>> def fd(x, y):
  188. >>> return grad_all(clip_test)(x, y)
  189. >>>
  190. >>> print("forward: ", f(1.1, 0.1))
  191. >>> print("clip_gradient:", fd(1.1, 0.1))
  192. """
  193. @prim_attr_register
  194. def __init__(self, f):
  195. self.f = f
  196. def infer_shape(self, x_shape):
  197. return x_shape
  198. def infer_dtype(self, x_type):
  199. return x_type
  200. class HookBackward(PrimitiveWithInfer):
  201. """
  202. This operation is used as a tag to hook gradient in intermediate variables. Note that this function
  203. is only supported in Pynative Mode.
  204. Note:
  205. The hook function must be defined like `hook_fn(grad) -> Tensor or None`,
  206. where grad is the gradient passed to the primitive and gradient may be
  207. modified and passed to next primitive. The difference between a hook function and
  208. callback of InsertGradientOf is that a hook function is executed in the python
  209. environment while callback will be parsed and added to the graph.
  210. Args:
  211. hook_fn (Function): Python function. hook function.
  212. Inputs:
  213. - **inputs** (Tensor) - The variable to hook.
  214. Examples:
  215. >>> def hook_fn(grad_out):
  216. >>> print(grad_out)
  217. >>>
  218. >>> grad_all = GradOperation(get_all=True)
  219. >>> hook = P.HookBackward(hook_fn)
  220. >>>
  221. >>> def hook_test(x, y):
  222. >>> z = x * y
  223. >>> z = hook(z)
  224. >>> z = z * y
  225. >>> return z
  226. >>>
  227. >>> def backward(x, y):
  228. >>> return grad_all(hook_test)(x, y)
  229. >>>
  230. >>> backward(1, 2)
  231. """
  232. def __init__(self, hook_fn, cell_id=""):
  233. super(HookBackward, self).__init__(self.__class__.__name__)
  234. self.add_prim_attr("cell_id", cell_id)
  235. self.init_attrs["cell_id"] = cell_id
  236. if not isinstance(hook_fn, (FunctionType, MethodType)):
  237. raise TypeError("Hook function should be python function type.")
  238. self.register_hook(hook_fn)
  239. self.cell_id = cell_id
  240. def infer_shape(self, *inputs_shape):
  241. if len(inputs_shape) == 1:
  242. return inputs_shape[0]
  243. return inputs_shape
  244. def infer_dtype(self, *inputs_type):
  245. if len(inputs_type) == 1:
  246. return inputs_type[0]
  247. return inputs_type
  248. class Print(PrimitiveWithInfer):
  249. """
  250. Outputs tensor or string to stdout.
  251. Note:
  252. In pynative mode, please use python print function.
  253. Inputs:
  254. - **input_x** (Union[Tensor, str]) - The graph node to attach to. The input supports
  255. multiple strings and tensors which are separated by ','.
  256. Examples:
  257. >>> class PrintDemo(nn.Cell):
  258. >>> def __init__(self):
  259. >>> super(PrintDemo, self).__init__()
  260. >>> self.print = P.Print()
  261. >>>
  262. >>> def construct(self, x, y):
  263. >>> self.print('Print Tensor x and Tensor y:', x, y)
  264. >>> return x
  265. """
  266. @prim_attr_register
  267. def __init__(self):
  268. self.add_prim_attr("_side_effect", True)
  269. def __call__(self, *args):
  270. for arg in args:
  271. print(arg)
  272. def infer_shape(self, *inputs):
  273. return [1]
  274. def infer_dtype(self, *inputs):
  275. for dtype in inputs:
  276. validator.check_subclass("input", dtype, (mstype.tensor, mstype.string), self.name)
  277. return mstype.int32
  278. class Assert(PrimitiveWithInfer):
  279. """
  280. Asserts that the given condition is true.
  281. If input condition evaluates to false, print the list of tensor in data.
  282. Args:
  283. summarize (int): Print this many entries of each tensor.
  284. Inputs:
  285. - **condition** [Union[Tensor[bool], bool]] - The condition to evaluate.
  286. - **input_data** (Union(tuple[Tensor], list[Tensor])) - The tensors to print out when condition is false.
  287. Examples:
  288. >>> class AssertDemo(nn.Cell):
  289. >>> def __init__(self):
  290. >>> super(AssertDemo, self).__init__()
  291. >>> self.assert = P.Assert(summarize=10)
  292. >>> self.add = P.TensorAdd()
  293. >>>
  294. >>> def construct(self, x, y):
  295. >>> data = self.add(x, y)
  296. >>> self.assert(True, [data])
  297. >>> return data
  298. """
  299. @prim_attr_register
  300. def __init__(self, summarize=3):
  301. """Initialize Assert"""
  302. self.summarize = validator.check_value_type("summarize", summarize, [int], self.name)
  303. def infer_shape(self, condition, inputs):
  304. condition_len = len(condition)
  305. validator.check_integer("condition's rank", condition_len, 1, Rel.LE, self.name)
  306. if condition_len == 1:
  307. validator.check_integer("condition[0]", condition[0], 1, Rel.EQ, self.name)
  308. return [1]
  309. def infer_dtype(self, condition, inputs):
  310. validator.check_scalar_or_tensor_type_same({"condition": condition}, [mstype.bool_], self.name)
  311. for dtype in inputs:
  312. validator.check_subclass("input", dtype, [mstype.tensor], self.name)
  313. return mstype.int32