You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pytree.py 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. from typing import Callable, NamedTuple
  11. import numpy as np
  12. SUPPORTED_TYPE = {}
  13. NodeType = NamedTuple("NodeType", [("flatten", Callable), ("unflatten", Callable)])
  14. def register_supported_type(type, flatten, unflatten):
  15. SUPPORTED_TYPE[type] = NodeType(flatten, unflatten)
  16. def _dict_flatten(inp):
  17. aux_data = []
  18. results = []
  19. for key, value in sorted(inp.items()):
  20. results.append(value)
  21. aux_data.append(key)
  22. return results, tuple(aux_data)
  23. def _dict_unflatten(inps, aux_data):
  24. return dict(zip(aux_data, inps))
  25. register_supported_type(list, lambda x: (x, None), lambda x, aux_data: list(x))
  26. register_supported_type(tuple, lambda x: (x, None), lambda x, aux_data: tuple(x))
  27. register_supported_type(dict, _dict_flatten, _dict_unflatten)
  28. register_supported_type(
  29. slice,
  30. lambda x: ([x.start, x.stop, x.step], None),
  31. lambda x, aux_data: slice(x[0], x[1], x[2]),
  32. )
  33. def tree_flatten(
  34. values,
  35. leaf_type: Callable = lambda x: type(x),
  36. is_leaf: Callable = lambda _: True,
  37. is_const_leaf: Callable = lambda _: False,
  38. ):
  39. if type(values) not in SUPPORTED_TYPE:
  40. assert is_leaf(values)
  41. node = LeafDef(leaf_type(values))
  42. if is_const_leaf(values):
  43. if isinstance(values, np.ndarray):
  44. node.const_val = str(values)
  45. else:
  46. node.const_val = values
  47. return [values,], node
  48. rst = []
  49. children_defs = []
  50. children_values, aux_data = SUPPORTED_TYPE[type(values)].flatten(values)
  51. for v in children_values:
  52. v_list, treedef = tree_flatten(v, leaf_type, is_leaf, is_const_leaf)
  53. rst.extend(v_list)
  54. children_defs.append(treedef)
  55. return rst, TreeDef(type(values), aux_data, children_defs)
  56. class TreeDef:
  57. def __init__(self, type, aux_data, children_defs):
  58. self.type = type
  59. self.aux_data = aux_data
  60. self.children_defs = children_defs
  61. self.num_leaves = sum(ch.num_leaves for ch in children_defs)
  62. def unflatten(self, leaves):
  63. assert len(leaves) == self.num_leaves
  64. start = 0
  65. children = []
  66. for ch in self.children_defs:
  67. children.append(ch.unflatten(leaves[start : start + ch.num_leaves]))
  68. start += ch.num_leaves
  69. return SUPPORTED_TYPE[self.type].unflatten(children, self.aux_data)
  70. def __hash__(self):
  71. return hash(
  72. tuple(
  73. [
  74. self.type,
  75. self.aux_data,
  76. self.num_leaves,
  77. tuple([hash(x) for x in self.children_defs]),
  78. ]
  79. )
  80. )
  81. def __eq__(self, other):
  82. return (
  83. self.type == other.type
  84. and self.aux_data == other.aux_data
  85. and self.num_leaves == other.num_leaves
  86. and self.children_defs == other.children_defs
  87. )
  88. def __repr__(self):
  89. return "{}[{}]".format(self.type.__name__, self.children_defs)
  90. class LeafDef(TreeDef):
  91. def __init__(self, type):
  92. if not isinstance(type, collections.abc.Sequence):
  93. type = (type,)
  94. super().__init__(type, None, [])
  95. self.num_leaves = 1
  96. self.const_val = None
  97. def unflatten(self, leaves):
  98. assert len(leaves) == 1
  99. assert isinstance(leaves[0], self.type), self.type
  100. return leaves[0]
  101. def __eq__(self, other):
  102. return self.type == other.type and self.const_val == other.const_val
  103. def __hash__(self):
  104. return hash(tuple([self.type, self.const_val]))
  105. def __repr__(self):
  106. return "Leaf({}[{}])".format(
  107. ", ".join(t.__name__ for t in self.type), self.const_val
  108. )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。