You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

base.py 9.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. # This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  2. #
  3. # Copyright 2020 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ============================================================================
  17. """Basic composite operations."""
  18. from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
  19. TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_
  20. from ...common import dtype as mstype
  21. from ...common.api import ms_function
  22. from .. import functional as F
  23. from .. import operations as P
  24. __all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_]
  25. def add_flags(fn, **flags):
  26. """
  27. An interface to add flag for a function.
  28. Note:
  29. Only supports bool value.
  30. Args:
  31. fn (Function): Function or cell to add flag.
  32. flags (bool): Flags use kwargs.
  33. Returns:
  34. Function, the fn added flags.
  35. Examples:
  36. >>> add_flags(net, predit=True)
  37. """
  38. # need set the attr and access on c++
  39. if not hasattr(fn, "_mindspore_flags"):
  40. fn._mindspore_flags = {}
  41. fn._mindspore_flags.update({**flags})
  42. return fn
  43. def core(fn=None, **flags):
  44. """
  45. A decorator to add flag to a function.
  46. By default, the function is marked core=True using this decorator to
  47. set flag to a graph.
  48. Args:
  49. fn (Function): Function to add flag. Default: None.
  50. flags (dict): The following flags can be set core, which indicates that this is a core function or
  51. other flag. Default: None.
  52. """
  53. # need set the attr and access on c++
  54. def deco(fn):
  55. fn._mindspore_flags = {
  56. 'core': True,
  57. **flags,
  58. }
  59. return fn
  60. if fn is not None:
  61. ret = deco(fn)
  62. else:
  63. ret = deco
  64. return ret
  65. class GradOperation(GradOperation_):
  66. """
  67. An metafuncgraph object which is used to get the gradient of output of a network(function).
  68. The GradOperation will convert the network(function) into a back propagation graph.
  69. Args:
  70. get_all (bool): If True, get all the gradients w.r.t inputs. Default: False.
  71. get_by_list (bool): If True, get all the gradients w.r.t Parameter variables.
  72. If get_all and get_by_list are both False, get the gradient w.r.t first input.
  73. If get_all and get_by_list are both True, get the gradients w.r.t inputs and Parameter variables
  74. at the same time in the form of ((grads w.r.t inputs), (grads w.r.t parameters)). Default: False.
  75. sens_param (bool): Whether append sensitivity as input. If sens_param is False,
  76. a 'ones_like(outputs)' sensitivity will be attached automatically. Default: False.
  77. """
  78. def __init__(self, name,
  79. get_all=False, get_by_list=False, sens_param=False):
  80. self.get_all = get_all
  81. self.get_by_list = get_by_list
  82. self.sens_param = sens_param
  83. GradOperation_.__init__(self, name, get_all, get_by_list, sens_param)
  84. self.grad_fn = None
  85. self.fn = None
  86. def __call__(self, fn, weights=None):
  87. grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)
  88. if self.grad_fn is None or self.fn != fn:
  89. if self.get_by_list:
  90. @ms_function(obj=fn)
  91. def after_grad(*args):
  92. return grad_(fn, weights)(*args)
  93. else:
  94. @ms_function(obj=fn)
  95. def after_grad(*args):
  96. return grad_(fn)(*args)
  97. self.grad_fn = after_grad
  98. self.fn = fn
  99. return self.grad_fn
# Pre-built GradOperation singletons covering the common gradient queries
# (w.r.t. first input, all inputs, or a parameter list; with or without an
# explicit sensitivity input).
grad = GradOperation('grad')
grad_all = GradOperation('get_all', get_all=True)
grad_by_list = GradOperation('get_by_list', get_by_list=True)
grad_with_sens = GradOperation('grad_with_sens', sens_param=True)
grad_all_with_sens = GradOperation('grad_all_with_sens', get_all=True, sens_param=True)
grad_by_list_with_sens = GradOperation('grad_by_list_with_sens', get_by_list=True, sens_param=True)
  106. class MultitypeFuncGraph(MultitypeFuncGraph_):
  107. """
  108. Generate multiply graph.
  109. MultitypeFuncGraph is a class used to generate graphs for function with different type as input.
  110. Args:
  111. name (str): Operator name.
  112. Raises:
  113. ValueError: Cannot find matching fn for the given args.
  114. Examples:
  115. >>> # `add` is a metagraph object which will add two objects according to
  116. >>> # input type using ".register" decorator.
  117. >>> add = MultitypeFuncGraph('add')
  118. """
  119. def __init__(self, name):
  120. MultitypeFuncGraph_.__init__(self, name)
  121. self.entries = list()
  122. def __call__(self, *args):
  123. for sig, fn in self.entries:
  124. if len(sig) != len(args):
  125. continue
  126. output = fn(*args)
  127. return output
  128. raise ValueError("Cannot find fn match given args.")
  129. def register(self, *type_names):
  130. """Register a function for the given type string."""
  131. def deco(fn):
  132. self.register_fn(type_names, fn)
  133. self.entries.append((type_names, fn))
  134. return fn
  135. return deco
  136. class HyperMap(HyperMap_):
  137. """
  138. Hypermap will apply the set operation on input sequences.
  139. Which will apply the operations of every elements of the sequence.
  140. Args:
  141. ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
  142. the operations should be putted in the first input of the instance.
  143. Inputs:
  144. - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be the same length sequences,
  145. and each row of the sequences. e.g. If args length is 2, and for `i` in length of each sequence
  146. `(args[0][i], args[1][i])` will be the input of the operation.
  147. If `ops` is not `None`, the first input is the operation, and the other is inputs.
  148. Outputs:
  149. sequence, the output will be same type and same length of sequence from input and the value of each element
  150. is the result of operation apply each row of element. e.g. `operation(args[0][i], args[1][i])`.
  151. """
  152. def __init__(self, ops=None):
  153. self.ops = ops
  154. if ops:
  155. HyperMap_.__init__(self, ops)
  156. else:
  157. HyperMap_.__init__(self)
  158. def __call__(self, *args):
  159. func = args[0]
  160. count = 0
  161. count_max = 1
  162. args_list = args[1:]
  163. if self.ops is not None:
  164. func = self.ops
  165. args_list = args
  166. for item in args_list:
  167. if isinstance(item, (tuple, list)):
  168. count_max = len(item)
  169. break
  170. def get_item(x):
  171. nonlocal count
  172. if isinstance(x, (tuple, list)):
  173. return x[count]
  174. return x
  175. for i in range(count_max):
  176. true_args = tuple(map(get_item, args_list))
  177. func(*true_args)
  178. count = i + 1
  179. return True
  180. def register(self, *type_names):
  181. """Register a function for the given type string."""
  182. def deco(fn):
  183. self.register_fn(type_names, fn)
  184. return fn
  185. return deco
  186. class _ListAppend(ListAppend_):
  187. """
  188. A metafuncgraph class that append one element to list.
  189. Args:
  190. name (str): The name of the metafuncgraph object.
  191. """
  192. def __init__(self, name):
  193. ListAppend_.__init__(self, name)
  194. def __call__(self, *args):
  195. pass
  196. _append = _ListAppend("append")
  197. class _Tail(Tail_):
  198. """
  199. A metafuncgraph class that generates tail elements of the tuple.
  200. Args:
  201. name (str): The name of the metafuncgraph object.
  202. """
  203. def __init__(self, name):
  204. Tail_.__init__(self, name)
  205. def __call__(self, *args):
  206. pass
  207. tail = _Tail('tail')
  208. class _ZipOperation(ZipOperation_):
  209. """Generates a tuple of zip iterations for inputs."""
  210. def __init__(self, name):
  211. ZipOperation_.__init__(self, name)
  212. def __call__(self, *args):
  213. pass
  214. zip_operation = _ZipOperation('zip_operation')
  215. """`zip_operation` will generate a tuple of zip iterations of inputs."""
# Multitype graph used during backprop to pull a parameter's gradient entry
# out of an EnvInstance.
env_get = MultitypeFuncGraph("env_get")


@env_get.register("EnvType", "Tensor")
def _tensor_env_get(env, parameter):
    """Used to get env."""
    # Looks up `env` by the parameter's embedding key; the third argument is
    # presumably the default returned when the key is absent (a zeros tensor
    # shaped like the parameter) — verify against F.env_getitem's contract.
    # NOTE(review): body is parsed by the graph compiler; keep it simple.
    return F.env_getitem(env, F.ref_to_embed(parameter), F.zeros_like_tensor(parameter))
# Multitype helper used for mixed-precision execution: float tensors are cast
# to the requested type, other values pass through unchanged.
_mp_cast_helper = MultitypeFuncGraph('mixed_precision_cast_helper')


@_mp_cast_helper.register("TypeType", "Number")
@core
def _mixed_precision_cast_helper_1(type_, x):
    """if x is float cast to type."""
    # type_ is place holder
    # Plain numbers are returned unchanged; only tensors are cast (below).
    return x


@_mp_cast_helper.register("TypeType", "Tensor")
@core
def _mixed_precision_cast_helper_2(type_, x):
    """if x is float cast to type."""
    # Only float-typed tensors are cast; integer/bool tensors pass through.
    if F.issubclass_(F.dtype(x), mstype.float_):
        return P.Cast()(x, type_)
    return x


@_mp_cast_helper.register("TypeType", "Tuple")
@core
def _mixed_precision_cast_helper_3(type_, x):
    """if x is a tuple"""
    # Recursively applies the cast helper to every element, rebuilding the
    # tuple. NOTE(review): the explicit accumulation loop is presumably kept
    # for graph-compiler compatibility — confirm before rewriting as a
    # comprehension.
    t = ()
    for item in x:
        t = t + (_mp_cast_helper(type_, item),)
    return t