You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

autodiff.py 18 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. #!/usr/bin/env python3
  2. # coding: utf-8
  3. # Copyright 2019 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """Automatic differentiation of tensor expressions."""
  17. import akg
  18. from akg.tvm._ffi.function import _init_api
  19. from akg.tvm._ffi.node import NodeBase, register_node
  20. from akg.utils.format_transform import get_shape
  21. _init_api("akg.autodiff")
  22. def collect_subtensors_by_name(tensor, name, result):
  23. """
  24. find all the subtensors with names matched the pattern `name`.
  25. Args:
  26. tensor: An input tensor.
  27. name: the `name` pattern to be matched.
  28. result: list of all subtensors found with name matched.
  29. Returns:
  30. list of all subtensors found with name matched.
  31. """
  32. for child in tensor.op.input_tensors:
  33. child_result = collect_by_name(child, name, result)
  34. result.extend(child_result)
  35. if tensor.op.name.find(name) != -1:
  36. result.append([tensor])
  37. return result
  38. @akg.tvm.register_func("akg.autodiff.export_to_DOT")
  39. def export_to_dot(tensors, filename="test.dot"):
  40. """
  41. Export computation tree of tensors to a DOT file.
  42. Args:
  43. tensors: A single/list/array of input tensors.
  44. filename: the name of the DOT file to be generated.
  45. """
  46. def export_tensor_shape(a_shape):
  47. result = "("
  48. for _, a_shp in enumerate(a_shape):
  49. result = result + str(a_shp.value) + ", "
  50. result = result + ")"
  51. return result
  52. def recursive_collect_nodes(tensor, exported_op_nodes, repeat_name):
  53. if tensor in exported_op_nodes:
  54. return exported_op_nodes, repeat_name
  55. if not exported_op_nodes:
  56. exported_op_nodes = {tensor: tensor.op.name}
  57. else:
  58. if tensor.op.name in exported_op_nodes.values():
  59. exported_op_nodes[tensor] = tensor.op.name + '_r' + str(repeat_name)
  60. repeat_name = repeat_name + 1
  61. else:
  62. exported_op_nodes[tensor] = tensor.op.name
  63. # exported_op_nodes[tensor] contains the name in DOT for "tensor"
  64. # If name is duplicated, a postfix '-r' + number is add to the end
  65. for child in tensor.op.input_tensors:
  66. if child not in exported_op_nodes:
  67. exported_op_nodes, repeat_name = recursive_collect_nodes(child, exported_op_nodes, repeat_name)
  68. return exported_op_nodes, repeat_name
  69. def export_node_name(tensor):
  70. if isinstance(tensor.op, akg.tvm.tensor.ComputeOp):
  71. if isinstance(tensor.op.body[0], akg.tvm.expr.Reduce):
  72. tensor_opcode_name = 'Reduce'
  73. elif isinstance(tensor.op.body[0], akg.tvm.expr.Mul):
  74. tensor_opcode_name = '*'
  75. elif isinstance(tensor.op.body[0], akg.tvm.expr.Add):
  76. tensor_opcode_name = '+'
  77. elif isinstance(tensor.op.body[0], akg.tvm.expr.Sub):
  78. tensor_opcode_name = '-'
  79. elif isinstance(tensor.op.body[0], akg.tvm.expr.Div):
  80. tensor_opcode_name = '/'
  81. elif isinstance(tensor.op.body[0], akg.tvm.expr.Call):
  82. tensor_opcode_name = 'Call ' + tensor.op.body[0].name
  83. elif isinstance(tensor.op.body[0], akg.tvm.expr.Cast):
  84. tensor_opcode_name = 'Cast:' + tensor.op.input_tensors[0].dtype + '=>' + tensor.dtype
  85. else:
  86. tensor_opcode_name = 'Unsupported yet OP'
  87. tensor_node_name = ' "' + exported_op_nodes[tensor] + '" [label = "' + exported_op_nodes[tensor] +\
  88. '\\n' + export_tensor_shape(tensor.shape) + '; ' + tensor.dtype + '\\n' +\
  89. tensor_opcode_name + '"; shape = ellipse; style = filled; color = lightgrey];'
  90. else: # isinstance(tensor.op,akg.tvm.tensor.PlaceholderOp):
  91. tensor_node_name = ' "' + exported_op_nodes[tensor] + '" [label = "' + exported_op_nodes[tensor] +\
  92. '\\n' + export_tensor_shape(tensor.shape) +\
  93. '"; shape = box; style = filled; color = lightseagreen];'
  94. return tensor_node_name
  95. def recursive_export_nodes_name(tensor, f, exported_op_nodes):
  96. for child in tensor.op.input_tensors:
  97. recursive_export_nodes_name(child, f, exported_op_nodes)
  98. if isinstance(tensor.op, akg.tvm.tensor.ComputeOp):
  99. if isinstance(tensor.op.body[0], (akg.tvm.expr.Mul, akg.tvm.expr.Add, akg.tvm.expr.Sub, akg.tvm.expr.Div)):
  100. if len(tensor.op.input_tensors) < 2:
  101. if isinstance(tensor.op.body[0].a, akg.tvm.expr.FloatImm):
  102. tensor_node_name = ' "Const_a_' + exported_op_nodes[tensor] +\
  103. '" [label = "' + str(tensor.op.body[0].a.value) + '\\n' +\
  104. tensor.op.body[0].a.dtype +\
  105. '"; shape = box; style = filled; color = lightseagreen];'
  106. f.write(tensor_node_name + "\n")
  107. if isinstance(tensor.op.body[0].b, akg.tvm.expr.FloatImm):
  108. tensor_node_name = ' "Const_b_' + exported_op_nodes[tensor] +\
  109. '" [label = "' + str(tensor.op.body[0].b.value) + '\\n' +\
  110. tensor.op.body[0].b.dtype +\
  111. '"; shape = box; style = filled; color = lightseagreen];'
  112. f.write(tensor_node_name + "\n")
  113. f.write(export_node_name(tensor) + "\n")
  114. def recursive_export_edges(tensor, f, exported_op_nodes, exported_edges):
  115. to_name = '"' + exported_op_nodes[tensor] + '"'
  116. for child in tensor.op.input_tensors:
  117. recursive_export_edges(child, f, exported_op_nodes, exported_edges)
  118. from_name = '"' + exported_op_nodes[child] + '"'
  119. if (from_name, to_name) not in exported_edges:
  120. exported_edges.add((from_name, to_name))
  121. f.write(' ' + from_name + " -> " + to_name
  122. + ' [label = "' + export_tensor_shape(child.shape) + '"];\n')
  123. if isinstance(tensor.op, akg.tvm.tensor.ComputeOp):
  124. if isinstance(tensor.op.body[0], (akg.tvm.expr.Mul, akg.tvm.expr.Add, akg.tvm.expr.Sub, akg.tvm.expr.Div)):
  125. if len(tensor.op.input_tensors) < 2:
  126. if isinstance(tensor.op.body[0].a, akg.tvm.expr.FloatImm):
  127. from_name = '"Const_a_' + exported_op_nodes[tensor] + '"'
  128. if (from_name, to_name) not in exported_edges:
  129. exported_edges.add((from_name, to_name))
  130. f.write(' ' + from_name + " -> " + to_name + ' [label = "(const)"];\n')
  131. if isinstance(tensor.op.body[0].b, akg.tvm.expr.FloatImm):
  132. from_name = '"Const_b_' + exported_op_nodes[tensor] + '"'
  133. if (from_name, to_name) not in exported_edges:
  134. exported_edges.add((from_name, to_name))
  135. f.write(' ' + from_name + " -> " + to_name + ' [label = "(const)"];\n')
  136. return exported_edges
  137. with open(filename, "w+") as f_out:
  138. f_out.write('digraph G {\n ration = compress;\n nodesep = 0.1; rankdir = BT\n')
  139. exported_op_nodes = dict() # dict of {tensor, tensor_name}
  140. exported_edges = set()
  141. repeat_name = 0
  142. if isinstance(tensors, akg.tvm.container.Array):
  143. list_tensors = [x for x in tensors]
  144. else:
  145. if isinstance(tensors, akg.tvm.tensor.Tensor):
  146. list_tensors = [tensors]
  147. else:
  148. list_tensors = []
  149. for a_tensor in list_tensors:
  150. exported_op_nodes, repeat_name = recursive_collect_nodes(a_tensor, exported_op_nodes, repeat_name)
  151. recursive_export_nodes_name(a_tensor, f_out, exported_op_nodes)
  152. exported_edges = recursive_export_edges(a_tensor, f_out, exported_op_nodes, exported_edges)
  153. f_out.write("\n}\n")
  154. variable_map = {}
  155. def register_variables(name, input, output):
  156. """register variables as a dictionary."""
  157. if not isinstance(name, str):
  158. raise ValueError("key {} is not str.".format(name))
  159. variable_map[name] = [output, input]
  160. def get_variables(name):
  161. """get variables from dictionary."""
  162. if isinstance(name, str):
  163. if not variable_map[name]:
  164. raise ValueError("value to key {} is empty.".format(name))
  165. return variable_map[name]
  166. raise ValueError("key {} is not str.".format(name))
  167. @register_node
  168. class DifferentiationResult(NodeBase):
  169. """
  170. Result of differentiation.
  171. Args:
  172. result (list[tvm.tensor.Tensor]):
  173. The requested adjoints, i.e. the Jacobians or gradients of the given output
  174. wrt to the given inputs.
  175. adjoints (dict[tvm.tensor.Tensor, tvm.tensor.Tensor]):
  176. A map from tensors to the corresponding adjoints (including internal nodes).
  177. adjoint_summands (dict[tvm.tensor.Tensor, dict[tvm.tensor.Tensor, tvm.tensor.Tensor]]):
  178. Single summands of the adjoints.
  179. """
  180. # Here we convert tvm Maps to dicts because Map compares keys by reference which is
  181. # wrong for tvm.tensor.Tensors. Hopefully, in the future Map gets fixed somehow, and these properties
  182. # may be removed then.
  183. @property
  184. def adjoints(self):
  185. res = NodeBase.__getattr__(self, 'adjoints')
  186. return dict(res.items())
  187. @property
  188. def adjoint_summands(self):
  189. res = NodeBase.__getattr__(self, 'adjoint_summands')
  190. return {k: dict(v.items()) for k, v in res.items()}
  191. def _check_not_empty(self):
  192. if not self.result:
  193. raise ValueError("The result of differentiation does not contain any explicitly "
  194. "requested results, so using it as an iterable is probably a mistake. "
  195. "Please explicitly use res.adjoints to get adjoints or res.result to "
  196. "get the empty list.")
  197. def __getitem__(self, i):
  198. self._check_not_empty()
  199. return self.result[i]
  200. def __len__(self):
  201. self._check_not_empty()
  202. return len(self.result)
  203. def differentiate(output, inputs=None, head=None, ad_attrs=None, new_pld_array=None, override=None, fdiff=None):
  204. """
  205. Perform operator-level automatic differentiation.
  206. Args:
  207. output (tvm.tensor.Tensor): The tensor to differentiate.
  208. inputs (list[tvm.tensor.Tensor]): The list of input tensors.
  209. When the list is empty or None, will perform differentiation with respect to all tensors the output depends
  210. on (i.e. will compute all adjoints and populate the corresponding dict, but the list of results will be
  211. empty). Default: None.
  212. head (tvm.tensor.Tensor): The adjoint of the output.
  213. in other words, some tensors, by which the Jacobians will be multiplied. Its shape must be of the form
  214. `prefix + output.shape`. For example, if the shape of `output` is (2, 3), the shape of `head` could
  215. be (2, 3), (?, 2, 3) and etc.
  216. If `None` is passed, the identity tensor of shape `output.shape + output.shape` will be used.
  217. Default: None.
  218. ad_attrs (dict): The additional attributes for the auto-differentiate computation. Default: None.
  219. new_pld_array (list): List of additional variables which could be used in differentiation. Default: None.
  220. override (dict): A dictionary to override differentiation for certain tensors.
  221. Override is a dictionary with types: {tvm.tensor.Tensor: (list[tvm.tensor.Tensor],
  222. callable[tvm.tensor.Tensor, list[tvm.tensor.Tensor], tvm.tensor.Tensor, list[tvm.tensor.Tensor]])}.
  223. This dict maps tensors `t` to pairs `(dependencies, custom_diff)` where `dependencies` is a list of
  224. tensors which are considered to be inputs of `t` (which may differ from the immediate inputs),
  225. and `custom_diff` is a custom differentiation function which will be called as
  226. `custom_diff(t, dependencies, adjoint, new_pld_array)` and should return a list of adjoints
  227. corresponding to dependencies.
  228. Note that this function differs from the one required for `fdiff`
  229. in that it takes a list of inputs instead of a single input
  230. and returns a list of adjoints instead of a single adjoint. Default: None.
  231. fdiff (callable[tvm.tensor.Tensor, tvm.tensor.Tensor, tvm.tensor.Tensor, tvm.tensor.Tensor]): The default
  232. function performing differentiation and multiplication, by default `akg.autodiff.DiffBuildingBlock` is used.
  233. The function must accept parameters:
  234. - `output` - an output tensor
  235. - `input` - an input tensor
  236. - `head` - the adjoint of the output tensor
  237. - `ad_attrs` - the additional attributes for the auto-differentiate computation
  238. - `new_pld_array` - the additional tensors with information for the auto-differentiate computation
  239. The result should be `head` multiplied by the Jacobians of `output` wrt `input`. Default: None.
  240. Returns:
  241. DifferentiationResult.
  242. class DifferentiationResult is used to represent a differentiation result, including:
  243. - result (list[tvm.tensor.Tensor]):
  244. The requested adjoints, i.e. the Jacobians or gradients of the given output
  245. with respect to the given inputs.
  246. - adjoints (dict{tvm.tensor.Tensor: tvm.tensor.Tensor}):
  247. A dict from tensors to the corresponding adjoints (including internal nodes).
  248. - adjoint_summands (dict{tvm.tensor.Tensor: dict{tvm.tensor.Tensor: tvm.tensor.Tensor}}):
  249. Single summands of the adjoints.
  250. Raises:
  251. ValueError: If the shape of `head` is invalid.
  252. Examples:
  253. >>> x = akg.tvm.placeholder((32, 3, 28, 28), name='x')
  254. >>> w1 = akg.tvm.placeholder((10, 3, 3, 3), name='w1')
  255. >>> z1 = akg.topi.nn.conv2d(x, w1, 1, 0, 1)
  256. >>> z2 = akg.topi.nn.flatten(z1)
  257. >>> y = akg.topi.sum(z2)
  258. >>>
  259. >>> # produce gradients
  260. >>> [dw1, dw2] = akg.differentiate(y, [x, w1])
  261. >>>
  262. >>> # produce Jacobians
  263. >>> [jw1, jw2] = akg.differentiate(z2, [x, w1])
  264. >>>
  265. >>> # produce Jacobians, the head adjoint for z2 is provided manually
  266. >>> [dw1, dw2] = akg.differentiate(z2, [x, w1], akg.topi.full_like(z2, 1.0))
  267. >>>
  268. >>> # produce gradients wrt all inputs
  269. >>> res = akg.differentiate(y)
  270. >>> dw1 = res.adjoints[x]
  271. >>> dw2 = res.adjoints[w1]
  272. >>>
  273. >>> # a custom differentiation function
  274. >>> head = akg.tvm.placeholder((1,), name = 'head')
  275. >>> def my_fdiff(out, inp, head, ad_attrs, new_pld_array):
  276. >>> return [akg.tvm.compute(inp[0].shape, lambda ax0, ax1, ax2, ax3: head[ax0, ax3 + ax2*26 + ax1*676])]
  277. >>>
  278. >>> # using a custom differentiation function only for z2
  279. >>> res = akg.differentiate(y, [x, w1], head, None, None, override={z2: ([z1], my_fdiff)})
  280. """
  281. # check whether head shape is compatible with output shape.
  282. if head is not None:
  283. output_shape = get_shape(output)
  284. head_shape = get_shape(head)
  285. output_dim = len(output_shape)
  286. head_last_shape = head_shape[-output_dim:]
  287. if head_last_shape != output_shape:
  288. raise ValueError("operands could not be broadcast together with head shape %s and output shape %s" %
  289. (str(head_shape), str(output_shape)))
  290. if inputs is None:
  291. inputs = []
  292. if override is not None:
  293. override_deps = []
  294. if fdiff is None:
  295. fdiff = DiffBuildingBlock
  296. if override is not None:
  297. def modified_fdiff(out, inp, head, ad_attrs, new_pld_array, override=override, old_fdiff=fdiff, cache=None):
  298. if cache is None:
  299. cache = {}
  300. if out in override:
  301. if (out, head) not in cache:
  302. cache[(out, head)] = override[out][1](out, override[out][0], head, ad_attrs, new_pld_array)
  303. idx = override[out][0].index(inp)
  304. return cache[(out, head)][idx]
  305. return old_fdiff(out, inp, head, ad_attrs, new_pld_array)
  306. fdiff = modified_fdiff
  307. override_deps = {t: deps for t, (deps, _) in override.items()}
  308. return akg.autodiff.Differentiate(output, inputs, head, ad_attrs, None, fdiff, override_deps)
  309. if new_pld_array is None:
  310. return akg.autodiff.Differentiate(output, inputs, head, ad_attrs, [], fdiff)
  311. return akg.autodiff.Differentiate(output, inputs, head, ad_attrs, new_pld_array, fdiff)