You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

composite.h 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_
  19. #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_
  20. #include <vector>
  21. #include <string>
  22. #include <unordered_map>
  23. #include <utility>
  24. #include <map>
  25. #include <set>
  26. #include <memory>
  27. #include "frontend/operator/composite/zip_operation.h"
  28. #include "frontend/operator/composite/list_append_operation.h"
  29. #include "frontend/operator/composite/do_signature.h"
  30. #include "frontend/operator/composite/unpack_call.h"
  31. #include "frontend/operator/composite/multitype_funcgraph.h"
  32. #include "pipeline/jit/static_analysis/static_analysis.h"
  33. #include "utils/misc.h"
  34. #include "utils/any.h"
  35. #include "ir/dtype.h"
  36. #include "ir/meta_func_graph.h"
  37. namespace mindspore {
  38. // namespace to support composite operators definition
  39. namespace prim {
  40. using AbstractSlicePtr = abstract::AbstractSlicePtr;
  41. using AbstractScalarPtr = abstract::AbstractScalarPtr;
  42. using AbstractTensorPtr = abstract::AbstractTensorPtr;
  43. using ElemwiseMap = std::unordered_map<std::string, PrimitivePtr>;
  44. using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
  45. class HyperMap : public MetaFuncGraph {
  46. public:
  47. explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr);
  48. HyperMap(const HyperMap &h);
  49. void Init();
  50. HyperMap &operator=(const HyperMap &h) {
  51. if (this != &h) {
  52. fn_leaf_ = h.fn_leaf_;
  53. broadcast_ = h.broadcast_;
  54. nonleaf_ = h.nonleaf_;
  55. if (fn_leaf_) {
  56. name_ = "hyper_map[" + fn_leaf_->name() + "]";
  57. }
  58. }
  59. return *this;
  60. }
  61. ~HyperMap() override = default;
  62. MS_DECLARE_PARENT(HyperMap, MetaFuncGraph)
  63. abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override;
  64. FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override;
  65. MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; }
  66. private:
  67. AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  68. const ArgsPairList &arg_map);
  69. AnfNodePtr FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  70. const ArgsPairList &arg_map);
  71. AnfNodePtr FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  72. const ArgsPairList &arg_map);
  73. AnfNodePtr FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  74. const ArgsPairList &arg_map);
  75. AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map);
  76. ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list);
  77. MultitypeFuncGraphPtr fn_leaf_;
  78. bool broadcast_;
  79. std::set<TypeId> nonleaf_;
  80. };
  81. using HyperMapPtr = std::shared_ptr<HyperMap>;
  82. class HyperMapPy : public HyperMap {
  83. public:
  84. explicit HyperMapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : HyperMap(fn_leaf) {}
  85. ~HyperMapPy() override = default;
  86. MS_DECLARE_PARENT(HyperMapPy, HyperMap)
  87. };
  88. using HyperMapPyPtr = std::shared_ptr<HyperMapPy>;
  89. extern ValuePtr kCompositeHyperMap;
  // Mode selector stored by Tail. Kept as an unscoped enum so the enumerators
  // stay usable without qualification (Tail's constructor defaults to kNotGrad).
  enum TailType {
    kGradAll,
    kGradFirst,
    kNotGrad
  };
  91. class Tail : public MetaFuncGraph {
  92. public:
  93. explicit Tail(const std::string &name, TailType tail_type = kNotGrad) : MetaFuncGraph(name), tail_type_(tail_type) {}
  94. ~Tail() override = default;
  95. MS_DECLARE_PARENT(Tail, MetaFuncGraph)
  96. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  97. FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const;
  98. friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
  99. private:
  100. TailType tail_type_;
  101. };
  102. using TailPtr = std::shared_ptr<Tail>;
  103. class MakeTupleGradient : public MetaFuncGraph {
  104. public:
  105. explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {}
  106. ~MakeTupleGradient() override = default;
  107. MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph)
  108. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  109. friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; }
  110. };
  111. using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>;
  112. class MakeListGradient : public MetaFuncGraph {
  113. public:
  114. explicit MakeListGradient(const std::string &name) : MetaFuncGraph(name) {}
  115. ~MakeListGradient() override = default;
  116. MS_DECLARE_PARENT(MakeListGradient, MetaFuncGraph)
  117. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  118. friend bool operator==(const MakeListGradient &lhs, const MakeListGradient &rhs) { return lhs.name_ == rhs.name_; }
  119. };
  120. using MakeListGradientPtr = std::shared_ptr<MakeListGradient>;
  121. class GradOperation : public MetaFuncGraph {
  122. public:
  123. explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
  124. bool sens_param = false);
  125. ~GradOperation() override = default;
  126. MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
  127. FuncGraphPtr GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights,
  128. const std::vector<AnfNodePtr> &forward_graph_params,
  129. const std::vector<AnfNodePtr> &weight_args = {});
  130. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  131. bool sens_param() const { return sens_param_; }
  132. bool get_all_;
  133. bool get_by_list_;
  134. bool sens_param_;
  135. private:
  136. void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
  137. const AnfNodePtr &weights);
  138. };
  139. using GradOperationPtr = std::shared_ptr<GradOperation>;
  140. class ListMap {
  141. public:
  142. explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); }
  143. ~ListMap() = default;
  144. void MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr);
  145. void MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr);
  146. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list);
  147. private:
  148. std::string name_;
  149. std::map<std::vector<AnyPtr>, FuncGraphPtr> cache_;
  150. };
  151. class TupleAdd : public MetaFuncGraph {
  152. public:
  153. explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {}
  154. ~TupleAdd() override = default;
  155. MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph)
  156. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  157. friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; }
  158. };
  159. using TupleAddPtr = std::shared_ptr<TupleAdd>;
  160. class TupleSlice : public MetaFuncGraph {
  161. public:
  162. explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {}
  163. ~TupleSlice() override = default;
  164. MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph)
  165. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  166. friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; }
  167. };
  168. using TupleSlicePtr = std::shared_ptr<TupleSlice>;
  169. class TupleGetItemTensor : public MetaFuncGraph {
  170. public:
  171. explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {}
  172. ~TupleGetItemTensor() override = default;
  173. MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph)
  174. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  175. friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) {
  176. return lhs.name_ == rhs.name_;
  177. }
  178. };
  179. using TupleGetItemTensorPtr = std::shared_ptr<TupleGetItemTensor>;
  180. } // namespace prim
  181. } // namespace mindspore
  182. #endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_