You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

debug_ops.py 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """debug_ops"""
  16. from types import FunctionType, MethodType
  17. from mindspore import context
  18. from ..._checkparam import Validator as validator
  19. from ..._checkparam import Rel
  20. from ...common import dtype as mstype
  21. from ..primitive import prim_attr_register, Primitive, PrimitiveWithInfer
  22. def _check_mode(class_name):
  23. """Check for PyNative mode."""
  24. mode = context.get_context('mode')
  25. if mode == context.PYNATIVE_MODE:
  26. raise RuntimeError(f'{class_name} operator does not support PyNative mode.')
  27. def _check_summary_param(name, value, class_name):
  28. """Checks the name and value is valid for summary."""
  29. _check_mode(class_name)
  30. n_type = name['dtype']
  31. n_value = name['value']
  32. validator.check_value_type('name', n_type, [type(mstype.string)], class_name)
  33. if not n_value:
  34. raise ValueError(f"For 'name' the value should by valid string in {class_name}, but got an empty string.")
  35. v_type = value['dtype']
  36. validator.check_value_type('value', v_type, [type(mstype.tensor)], class_name)
# Note: The return value of a summary operator is never used, so any return
# `dtype` or `shape` is acceptable. However, `value` must be None; otherwise
# summary operators may be optimized away at the compile-graph phase (constant
# folding), which would prevent them from recording data.
SUMMARY_RETURN_VALUE = {'dtype': mstype.int32, 'shape': [1], 'value': None}
  42. class ScalarSummary(Primitive):
  43. """
  44. Outputs a scalar to a protocol buffer through a scalar summary operator.
  45. Inputs:
  46. - **name** (str) - The name of the input variable, it must not be an empty string.
  47. - **value** (Tensor) - The value of scalar, and the shape of value must be [] or [1].
  48. Raises:
  49. TypeError: If `name` is not a str.
  50. TypeError: If `value` is not a Tensor.
  51. Supported Platforms:
  52. ``Ascend`` ``GPU`` ``CPU``
  53. Examples:
  54. >>> import mindspore.nn as nn
  55. >>> import mindspore.ops as ops
  56. >>>
  57. >>>
  58. >>> class SummaryDemo(nn.Cell):
  59. ... def __init__(self,):
  60. ... super(SummaryDemo, self).__init__()
  61. ... self.summary = ops.ScalarSummary()
  62. ... self.add = ops.Add()
  63. ...
  64. ... def construct(self, x, y):
  65. ... name = "x"
  66. ... self.summary(name, x)
  67. ... x = self.add(x, y)
  68. ... return x
  69. ...
  70. """
  71. @prim_attr_register
  72. def __init__(self):
  73. """Initialize ScalarSummary."""
  74. self.add_prim_attr("side_effect_io", True)
  75. class ImageSummary(PrimitiveWithInfer):
  76. """
  77. Outputs the image tensor to protocol buffer through image summary operator.
  78. Inputs:
  79. - **name** (str) - The name of the input variable, it must not be an empty string.
  80. - **value** (Tensor) - The value of image, the rank of tensor must be 4.
  81. Raises:
  82. TypeError: If `name` is not a str.
  83. TypeError: If `value` is not a Tensor.
  84. Supported Platforms:
  85. ``Ascend`` ``GPU`` ``CPU``
  86. Examples:
  87. >>> import mindspore.nn as nn
  88. >>> import mindspore.ops as ops
  89. >>>
  90. >>>
  91. >>> class Net(nn.Cell):
  92. ... def __init__(self):
  93. ... super(Net, self).__init__()
  94. ... self.summary = ops.ImageSummary()
  95. ...
  96. ... def construct(self, x):
  97. ... name = "image"
  98. ... out = self.summary(name, x)
  99. ... return out
  100. ...
  101. """
  102. @prim_attr_register
  103. def __init__(self):
  104. """Initialize ImageSummary."""
  105. self.add_prim_attr("side_effect_io", True)
  106. def __infer__(self, name, value):
  107. _check_summary_param(name, value, self.__class__.__name__)
  108. # The shape dim of image should be 4.
  109. v_shape = value['shape']
  110. image_dim = 4
  111. if len(v_shape) != image_dim:
  112. raise ValueError(f"For 'value' the dim should be {image_dim} in {self.__class__.__name__},"
  113. f" but got {len(v_shape)}.")
  114. return SUMMARY_RETURN_VALUE
  115. class TensorSummary(Primitive):
  116. """
  117. Outputs a tensor to a protocol buffer through a tensor summary operator.
  118. Inputs:
  119. - **name** (str) - The name of the input variable.
  120. - **value** (Tensor) - The value of tensor, and the rank of tensor must be greater than 0.
  121. Raises:
  122. TypeError: If `name` is not a str.
  123. TypeError: If `value` is not a Tensor.
  124. Supported Platforms:
  125. ``Ascend`` ``GPU`` ``CPU``
  126. Examples:
  127. >>> import mindspore.nn as nn
  128. >>> import mindspore.ops as ops
  129. >>>
  130. >>>
  131. >>> class SummaryDemo(nn.Cell):
  132. ... def __init__(self,):
  133. ... super(SummaryDemo, self).__init__()
  134. ... self.summary = ops.TensorSummary()
  135. ... self.add = ops.Add()
  136. ...
  137. ... def construct(self, x, y):
  138. ... x = self.add(x, y)
  139. ... name = "x"
  140. ... self.summary(name, x)
  141. ... return x
  142. ...
  143. """
  144. @prim_attr_register
  145. def __init__(self):
  146. """Initialize TensorSummary."""
  147. self.add_prim_attr("side_effect_io", True)
  148. class HistogramSummary(PrimitiveWithInfer):
  149. """
  150. Outputs the tensor to protocol buffer through histogram summary operator.
  151. Inputs:
  152. - **name** (str) - The name of the input variable.
  153. - **value** (Tensor) - The value of tensor, and the rank of tensor must be greater than 0.
  154. Raises:
  155. TypeError: If `name` is not a str.
  156. TypeError: If `value` is not a Tensor.
  157. Supported Platforms:
  158. ``Ascend`` ``GPU`` ``CPU``
  159. Examples:
  160. >>> import mindspore.nn as nn
  161. >>> import mindspore.ops as ops
  162. >>>
  163. >>>
  164. >>> class SummaryDemo(nn.Cell):
  165. ... def __init__(self,):
  166. ... super(SummaryDemo, self).__init__()
  167. ... self.summary = ops.HistogramSummary()
  168. ... self.add = ops.Add()
  169. ...
  170. ... def construct(self, x, y):
  171. ... x = self.add(x, y)
  172. ... name = "x"
  173. ... self.summary(name, x)
  174. ... return x
  175. ...
  176. """
  177. @prim_attr_register
  178. def __init__(self):
  179. """Initialize HistogramSummary."""
  180. self.add_prim_attr("side_effect_io", True)
  181. def __infer__(self, name, value):
  182. _check_summary_param(name, value, self.__class__.__name__)
  183. v_shape = value['shape']
  184. # In the summary, the histogram value should be a tensor whose shape is not [].
  185. if not v_shape:
  186. raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, "
  187. f"shape should not be [].")
  188. return SUMMARY_RETURN_VALUE
  189. class InsertGradientOf(PrimitiveWithInfer):
  190. """
  191. Attaches callback to the graph node that will be invoked on the node's gradient.
  192. Args:
  193. f (Function): MindSpore's Function. Callback function.
  194. Inputs:
  195. - **input_x** (Any) - The graph node to attach to.
  196. Outputs:
  197. Tensor, returns `input_x` directly. `InsertGradientOf` does not affect the forward result.
  198. Raises:
  199. TypeError: If `f` is not a function of mindspore.
  200. Supported Platforms:
  201. ``Ascend`` ``GPU`` ``CPU``
  202. Examples:
  203. >>> def clip_gradient(dx):
  204. ... ret = dx
  205. ... if ret > 1.0:
  206. ... ret = 1.0
  207. ...
  208. ... if ret < 0.2:
  209. ... ret = 0.2
  210. ...
  211. ... return ret
  212. ...
  213. >>> clip = ops.InsertGradientOf(clip_gradient)
  214. >>> grad_all = ops.GradOperation(get_all=True)
  215. >>> def InsertGradientOfClipDemo():
  216. ... def clip_test(x, y):
  217. ... x = clip(x)
  218. ... y = clip(y)
  219. ... c = x * y
  220. ... return c
  221. ...
  222. ... @ms_function
  223. ... def f(x, y):
  224. ... return clip_test(x, y)
  225. ...
  226. ... def fd(x, y):
  227. ... return grad_all(clip_test)(x, y)
  228. ...
  229. ... print("forward: ", f(1.1, 0.1))
  230. ... print("clip_gradient:", fd(1.1, 0.1))
  231. ...
  232. """
  233. @prim_attr_register
  234. def __init__(self, f):
  235. """Initialize InsertGradientOf."""
  236. self.add_prim_attr('side_effect_backprop', True)
  237. self.f = f
  238. def infer_shape(self, x_shape):
  239. return x_shape
  240. def infer_dtype(self, x_type):
  241. return x_type
  242. class HookBackward(PrimitiveWithInfer):
  243. """
  244. This operation is used as a tag to hook gradient in intermediate variables. Note that this function
  245. is only supported in Pynative Mode.
  246. Note:
  247. The hook function must be defined like `hook_fn(grad) -> Tensor or None`,
  248. where grad is the gradient passed to the primitive and gradient may be
  249. modified and passed to next primitive. The difference between a hook function and
  250. callback of InsertGradientOf is that a hook function is executed in the python
  251. environment while callback will be parsed and added to the graph.
  252. Args:
  253. hook_fn (Function): Python function. hook function.
  254. Inputs:
  255. - **inputs** (Tensor) - The variable to hook.
  256. Raises:
  257. TypeError: If `inputs` are not a Tensor.
  258. TypeError: If `hook_fn` is not a function of python.
  259. Examples:
  260. >>> def hook_fn(grad_out):
  261. ... print(grad_out)
  262. ...
  263. >>> grad_all = GradOperation(get_all=True)
  264. >>> hook = ops.HookBackward(hook_fn)
  265. >>> def hook_test(x, y):
  266. ... z = x * y
  267. ... z = hook(z)
  268. ... z = z * y
  269. ... return z
  270. ...
  271. >>> def backward(x, y):
  272. ... return grad_all(hook_test)(x, y)
  273. ...
  274. >>> output = backward(1, 2)
  275. >>> print(output)
  276. """
  277. def __init__(self, hook_fn, cell_id=""):
  278. """Initialize HookBackward."""
  279. super(HookBackward, self).__init__(self.__class__.__name__)
  280. self.add_prim_attr("cell_id", cell_id)
  281. self.init_attrs["cell_id"] = cell_id
  282. if not isinstance(hook_fn, (FunctionType, MethodType)):
  283. raise TypeError("Hook function should be python function type.")
  284. self.register_hook(hook_fn)
  285. self.cell_id = cell_id
  286. def infer_shape(self, *inputs_shape):
  287. if len(inputs_shape) == 1:
  288. return inputs_shape[0]
  289. return inputs_shape
  290. def infer_dtype(self, *inputs_type):
  291. if len(inputs_type) == 1:
  292. return inputs_type[0]
  293. return inputs_type
  294. class Print(PrimitiveWithInfer):
  295. """
  296. Outputs the tensor or string to stdout.
  297. Note:
  298. In pynative mode, please use python print function.
  299. In graph mode, the bool, int and float would be converted into Tensor to print,
  300. str remains unchanged.
  301. Inputs:
  302. - **input_x** (Union[Tensor, bool, int, float, str]) - The graph node to attach to.
  303. Supports multiple inputs which are separated by ','.
  304. Outputs:
  305. Tensor, has the same data type and shape as original `input_x`.
  306. Raises:
  307. TypeError: If `input_x` is not one of the following: Tensor, bool, int, float, str.
  308. Supported Platforms:
  309. ``Ascend`` ``GPU``
  310. Examples:
  311. >>> class PrintDemo(nn.Cell):
  312. ... def __init__(self):
  313. ... super(PrintDemo, self).__init__()
  314. ... self.print = ops.Print()
  315. ...
  316. ... def construct(self, x, y):
  317. ... self.print('Print Tensor x and Tensor y:', x, y)
  318. ... return x
  319. ...
  320. >>> x = Tensor(np.ones([2, 1]).astype(np.int32))
  321. >>> y = Tensor(np.ones([2, 2]).astype(np.int32))
  322. >>> net = PrintDemo()
  323. >>> result = net(x, y)
  324. Print Tensor x and Tensor y:
  325. Tensor(shape=[2, 1], dtype=Int32, value=
  326. [[1]
  327. [1]])
  328. Tensor(shape=[2, 2], dtype=Int32, value=
  329. [[1 1]
  330. [1 1]])
  331. """
  332. @prim_attr_register
  333. def __init__(self):
  334. """Initialize Print."""
  335. self.add_prim_attr("side_effect_io", True)
  336. def __call__(self, *args):
  337. for arg in args:
  338. print(arg)
  339. def infer_shape(self, *inputs):
  340. return [1]
  341. def infer_dtype(self, *inputs):
  342. # check argument types except the last one (io state).
  343. for ele in inputs[:-1]:
  344. validator.check_subclass("input", ele,
  345. [mstype.tensor, mstype.int_, mstype.float_, mstype.bool_, mstype.string],
  346. self.name)
  347. return mstype.int32
  348. class Assert(PrimitiveWithInfer):
  349. """
  350. Asserts that the given condition is True.
  351. If input condition evaluates to false, print the list of tensor in data.
  352. Args:
  353. summarize (int): Print this many entries of each tensor.
  354. Inputs:
  355. - **condition** [Union[Tensor[bool], bool]] - The condition to evaluate.
  356. - **input_data** (Union(tuple[Tensor], list[Tensor])) - The tensors to print out when condition is false.
  357. Raises:
  358. TypeError: If `summarize` is not an int.
  359. TypeError: If `condition` is neither a Tensor nor a bool.
  360. TypeError: If `input_data` is neither a tuple nor a list.
  361. Examples:
  362. >>> class AssertDemo(nn.Cell):
  363. ... def __init__(self):
  364. ... super(AssertDemo, self).__init__()
  365. ... self.assert1 = ops.Assert(summarize=10)
  366. ... self.add = ops.Add()
  367. ...
  368. ... def construct(self, x, y):
  369. ... data = self.add(x, y)
  370. ... self.assert1(True, [data])
  371. ... return data
  372. ...
  373. """
  374. @prim_attr_register
  375. def __init__(self, summarize=3):
  376. """Initialize Assert"""
  377. self.summarize = validator.check_value_type("summarize", summarize, [int], self.name)
  378. def infer_shape(self, condition, inputs):
  379. condition_len = len(condition)
  380. validator.check_int(condition_len, 1, Rel.LE, "condition's rank", self.name)
  381. if condition_len == 1:
  382. validator.check_equal_int(condition[0], 1, "condition[0]", self.name)
  383. return [1]
  384. def infer_dtype(self, condition, inputs):
  385. validator.check_scalar_or_tensor_types_same({"condition": condition}, [mstype.bool_], self.name)
  386. for dtype in inputs:
  387. validator.check_subclass("input", dtype, [mstype.tensor], self.name)
  388. return mstype.int32