You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pytree.py 4.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. from typing import Callable, NamedTuple
  11. SUPPORTED_TYPE = {}
  12. NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])
  13. def register_supported_type(type, flatten, unflatten):
  14. SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
  15. def _dict_flatten(inp):
  16. aux_data = []
  17. results = []
  18. for key, value in sorted(inp.items()):
  19. results.append(value)
  20. aux_data.append(key)
  21. return results, tuple(aux_data)
  22. def _dict_unflatten(inps, aux_data):
  23. return dict(zip(aux_data, inps))
  24. register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
  25. register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: list(x))
  26. register_supported_type(dict, _dict_flatten, _dict_unflatten)
  27. register_supported_type(
  28. slice,
  29. lambda x: ([x.start, x.stop, x.step], None),
  30. lambda x, aux_data: slice(x[0], x[1], x[2]),
  31. )
  32. def tree_flatten(
  33. values,
  34. leaf_type: Callable = lambda x: type(x),
  35. is_leaf: Callable = lambda _: True,
  36. is_const_leaf: Callable = lambda _: False,
  37. ):
  38. if type(values) not in SUPPORTED_TYPE:
  39. assert is_leaf(values)
  40. node = LeafDef(leaf_type(values))
  41. if is_const_leaf(values):
  42. node.const_val = values
  43. return [values,], node
  44. rst = []
  45. children_defs = []
  46. children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values)
  47. for v in children_values:
  48. v_list, treedef = tree_flatten(v, leaf_type, is_leaf, is_const_leaf)
  49. rst.extend(v_list)
  50. children_defs.append(treedef)
  51. return rst, TreeDef(type(values), aux_data, children_defs)
  52. class TreeDef:
  53. def __init__(self, type, aux_data, children_defs):
  54. self.type = type
  55. self.aux_data = aux_data
  56. self.children_defs = children_defs
  57. self.num_leaves = sum(ch.num_leaves for ch in children_defs)
  58. def unflatten(self, leaves):
  59. assert len(leaves) == self.num_leaves
  60. start = 0
  61. children = []
  62. for ch in self.children_defs:
  63. children.append(ch.unflatten(leaves[start : start + ch.num_leaves]))
  64. start += ch.num_leaves
  65. return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data)
  66. def __hash__(self):
  67. return hash(
  68. tuple(
  69. [
  70. self.type,
  71. self.aux_data,
  72. self.num_leaves,
  73. tuple([hash(x) for x in self.children_defs]),
  74. ]
  75. )
  76. )
  77. def __eq__(self, other):
  78. return (
  79. self.type == other.type
  80. and self.aux_data == other.aux_data
  81. and self.num_leaves == other.num_leaves
  82. and self.children_defs == other.children_defs
  83. )
  84. def __repr__(self):
  85. return "{}[{}]".format(self.type.__name__, self.children_defs)
  86. class LeafDef(TreeDef):
  87. def __init__(self, type):
  88. if not isinstance(type, collections.abc.Sequence):
  89. type = (type,)
  90. super().__init__(type, None, [])
  91. self.num_leaves = 1
  92. self.const_val = None
  93. def unflatten(self, leaves):
  94. assert len(leaves) == 1
  95. assert isinstance(leaves[0], self.type), self.type
  96. return leaves[0]
  97. def __eq__(self, other):
  98. return self.type == other.type and self.const_val == other.const_val
  99. def __hash__(self):
  100. return hash(tuple([self.type, self.const_val]))
  101. def __repr__(self):
  102. return "Leaf({}[{}])".format(
  103. ", ".join(t.__name__ for t in self.type), self.const_val
  104. )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。