You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

composite.h 9.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_
  19. #define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_
  20. #include <vector>
  21. #include <string>
  22. #include <unordered_map>
  23. #include <utility>
  24. #include <map>
  25. #include <set>
  26. #include <memory>
  27. #include "operator/composite/zip_operation.h"
  28. #include "operator/composite/list_append_operation.h"
  29. #include "operator/composite/do_signature.h"
  30. #include "operator/composite/unpack_call.h"
  31. #include "pipeline/static_analysis/static_analysis.h"
  32. #include "utils/misc.h"
  33. #include "utils/any.h"
  34. #include "ir/dtype.h"
  35. #include "ir/meta_func_graph.h"
  36. namespace mindspore {
  37. // namespace to support composite operators definition
  38. namespace prim {
  39. using AbstractSlicePtr = abstract::AbstractSlicePtr;
  40. using AbstractScalarPtr = abstract::AbstractScalarPtr;
  41. using AbstractTensorPtr = abstract::AbstractTensorPtr;
  42. using ElemwiseMap = std::unordered_map<std::string, PrimitivePtr>;
  43. using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
  44. class MultitypeFuncGraph : public MetaFuncGraph {
  45. public:
  46. explicit MultitypeFuncGraph(const std::string &name);
  47. ~MultitypeFuncGraph() override = default;
  48. MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph)
  49. using specialize_fn = FuncGraph *(*)(TypePtrList);
  50. // Register a method which specialize based on types vectors;
  51. virtual void Register(const TypePtrList &types, specialize_fn s_fn);
  52. virtual void Register(const TypePtrList &types, const py::function &py_fn);
  53. virtual void Register(const std::vector<std::string> &types_name, const py::function &py_fn);
  54. virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn);
  55. FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override;
  56. size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); }
  57. const std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> &GetPyFunctions() const {
  58. return fn_cache_py_;
  59. }
  60. private:
  61. std::unordered_map<TypePtrList, specialize_fn, TypeListHasher, TypeListEqual> fn_cache_;
  62. std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> fn_cache_py_;
  63. };
  64. using MultitypeFuncGraphPtr = std::shared_ptr<MultitypeFuncGraph>;
  65. class HyperMap : public MetaFuncGraph {
  66. public:
  67. explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr);
  68. HyperMap(const HyperMap &h);
  69. void Init();
  70. HyperMap &operator=(const HyperMap &h) {
  71. if (this != &h) {
  72. fn_leaf_ = h.fn_leaf_;
  73. broadcast_ = h.broadcast_;
  74. nonleaf_ = h.nonleaf_;
  75. if (fn_leaf_) {
  76. name_ = "hyper_map[" + fn_leaf_->name() + "]";
  77. }
  78. }
  79. return *this;
  80. }
  81. ~HyperMap() override = default;
  82. MS_DECLARE_PARENT(HyperMap, MetaFuncGraph)
  83. abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override;
  84. FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override;
  85. MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; }
  86. private:
  87. AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  88. const ArgsPairList &arg_map);
  89. AnfNodePtr FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  90. const ArgsPairList &arg_map);
  91. AnfNodePtr FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  92. const ArgsPairList &arg_map);
  93. AnfNodePtr FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  94. const ArgsPairList &arg_map);
  95. AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map);
  96. ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list);
  97. MultitypeFuncGraphPtr fn_leaf_;
  98. bool broadcast_;
  99. std::set<TypeId> nonleaf_;
  100. };
  101. using HyperMapPtr = std::shared_ptr<HyperMap>;
  102. class HyperMapPy : public HyperMap {
  103. public:
  104. explicit HyperMapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : HyperMap(fn_leaf) {}
  105. ~HyperMapPy() override = default;
  106. MS_DECLARE_PARENT(HyperMapPy, HyperMap)
  107. };
  108. using HyperMapPyPtr = std::shared_ptr<HyperMapPy>;
  109. extern ValuePtr kCompositeHyperMap;
  110. class Tail : public MetaFuncGraph {
  111. public:
  112. explicit Tail(const std::string &name) : MetaFuncGraph(name) {}
  113. ~Tail() override = default;
  114. MS_DECLARE_PARENT(Tail, MetaFuncGraph)
  115. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  116. FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple);
  117. FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list);
  118. friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
  119. };
  120. using TailPtr = std::shared_ptr<Tail>;
  121. class MakeTupleGradient : public MetaFuncGraph {
  122. public:
  123. explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {}
  124. ~MakeTupleGradient() override = default;
  125. MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph)
  126. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  127. friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; }
  128. };
  129. using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>;
  130. class GradOperation : public MetaFuncGraph {
  131. public:
  132. explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
  133. bool sens_param = false);
  134. ~GradOperation() override = default;
  135. MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
  136. FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector<AnfNodePtr> &ptrParams,
  137. bool applyJ = false);
  138. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  139. bool sens_param() const { return sens_param_; }
  140. bool get_all_;
  141. bool get_by_list_;
  142. bool sens_param_;
  143. private:
  144. void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights,
  145. ValueNodePtr opsTupleItem);
  146. };
  147. using GradOperationPtr = std::shared_ptr<GradOperation>;
  148. class ListMap {
  149. public:
  150. explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); }
  151. ~ListMap() = default;
  152. void MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr);
  153. void MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr);
  154. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list);
  155. private:
  156. std::string name_;
  157. std::map<std::vector<AnyPtr>, FuncGraphPtr> cache_;
  158. };
  159. class TupleAdd : public MetaFuncGraph {
  160. public:
  161. explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {}
  162. ~TupleAdd() override = default;
  163. MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph)
  164. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  165. friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; }
  166. };
  167. using TupleAddPtr = std::shared_ptr<TupleAdd>;
  168. class TupleSlice : public MetaFuncGraph {
  169. public:
  170. explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {}
  171. ~TupleSlice() override = default;
  172. MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph)
  173. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  174. friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; }
  175. };
  176. using TupleSlicePtr = std::shared_ptr<TupleSlice>;
  177. class TensorSlice : public MetaFuncGraph {
  178. public:
  179. explicit TensorSlice(const std::string &name) : MetaFuncGraph(name) {}
  180. ~TensorSlice() override = default;
  181. MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
  182. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  183. friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
  184. FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const;
  185. };
  186. using TensorSlicePtr = std::shared_ptr<TensorSlice>;
  187. } // namespace prim
  188. } // namespace mindspore
  189. #endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_