You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

composite.h 8.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_
  19. #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_
  20. #include <vector>
  21. #include <string>
  22. #include <unordered_map>
  23. #include <utility>
  24. #include <map>
  25. #include <set>
  26. #include <memory>
  27. #include "frontend/operator/composite/zip_operation.h"
  28. #include "frontend/operator/composite/list_append_operation.h"
  29. #include "frontend/operator/composite/do_signature.h"
  30. #include "frontend/operator/composite/unpack_call.h"
  31. #include "frontend/operator/composite/multitype_funcgraph.h"
  32. #include "pipeline/jit/static_analysis/static_analysis.h"
  33. #include "utils/misc.h"
  34. #include "utils/any.h"
  35. #include "ir/dtype.h"
  36. #include "ir/meta_func_graph.h"
  37. namespace mindspore {
  38. // namespace to support composite operators definition
  39. namespace prim {
  40. using AbstractSlicePtr = abstract::AbstractSlicePtr;
  41. using AbstractScalarPtr = abstract::AbstractScalarPtr;
  42. using AbstractTensorPtr = abstract::AbstractTensorPtr;
  43. using ElemwiseMap = std::unordered_map<std::string, PrimitivePtr>;
  44. using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
  45. class HyperMap : public MetaFuncGraph {
  46. public:
  47. explicit HyperMap(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr);
  48. HyperMap(const HyperMap &h);
  49. void Init();
  50. HyperMap &operator=(const HyperMap &h) {
  51. if (this != &h) {
  52. fn_leaf_ = h.fn_leaf_;
  53. broadcast_ = h.broadcast_;
  54. nonleaf_ = h.nonleaf_;
  55. if (fn_leaf_) {
  56. name_ = "hyper_map[" + fn_leaf_->name() + "]";
  57. }
  58. }
  59. return *this;
  60. }
  61. ~HyperMap() override = default;
  62. MS_DECLARE_PARENT(HyperMap, MetaFuncGraph)
  63. abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override;
  64. FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override;
  65. MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; }
  66. private:
  67. AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  68. const ArgsPairList &arg_map);
  69. AnfNodePtr FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  70. const ArgsPairList &arg_map);
  71. AnfNodePtr FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  72. const ArgsPairList &arg_map);
  73. AnfNodePtr FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  74. const ArgsPairList &arg_map);
  75. AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map);
  76. ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list);
  77. MultitypeFuncGraphPtr fn_leaf_;
  78. bool broadcast_;
  79. std::set<TypeId> nonleaf_;
  80. };
  81. using HyperMapPtr = std::shared_ptr<HyperMap>;
  82. class HyperMapPy : public HyperMap {
  83. public:
  84. explicit HyperMapPy(const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) : HyperMap(fn_leaf) {}
  85. ~HyperMapPy() override = default;
  86. MS_DECLARE_PARENT(HyperMapPy, HyperMap)
  87. };
  88. using HyperMapPyPtr = std::shared_ptr<HyperMapPy>;
  89. extern ValuePtr kCompositeHyperMap;
  90. class Tail : public MetaFuncGraph {
  91. public:
  92. explicit Tail(const std::string &name) : MetaFuncGraph(name) {}
  93. ~Tail() override = default;
  94. MS_DECLARE_PARENT(Tail, MetaFuncGraph)
  95. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  96. FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple);
  97. FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list);
  98. friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
  99. };
  100. using TailPtr = std::shared_ptr<Tail>;
  101. class MakeTupleGradient : public MetaFuncGraph {
  102. public:
  103. explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {}
  104. ~MakeTupleGradient() override = default;
  105. MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph)
  106. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  107. friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; }
  108. };
  109. using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>;
  110. class MakeListGradient : public MetaFuncGraph {
  111. public:
  112. explicit MakeListGradient(const std::string &name) : MetaFuncGraph(name) {}
  113. ~MakeListGradient() override = default;
  114. MS_DECLARE_PARENT(MakeListGradient, MetaFuncGraph)
  115. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  116. friend bool operator==(const MakeListGradient &lhs, const MakeListGradient &rhs) { return lhs.name_ == rhs.name_; }
  117. };
  118. using MakeListGradientPtr = std::shared_ptr<MakeListGradient>;
  119. class GradOperation : public MetaFuncGraph {
  120. public:
  121. explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
  122. bool sens_param = false);
  123. ~GradOperation() override = default;
  124. MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
  125. FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector<AnfNodePtr> &ptrParams,
  126. const std::vector<AnfNodePtr> &args = {}, bool applyJ = false);
  127. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  128. bool sens_param() const { return sens_param_; }
  129. bool get_all_;
  130. bool get_by_list_;
  131. bool sens_param_;
  132. private:
  133. void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights,
  134. ValueNodePtr opsTupleItem);
  135. };
  136. using GradOperationPtr = std::shared_ptr<GradOperation>;
  137. class ListMap {
  138. public:
  139. explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); }
  140. ~ListMap() = default;
  141. void MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr);
  142. void MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr);
  143. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list);
  144. private:
  145. std::string name_;
  146. std::map<std::vector<AnyPtr>, FuncGraphPtr> cache_;
  147. };
  148. class TupleAdd : public MetaFuncGraph {
  149. public:
  150. explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {}
  151. ~TupleAdd() override = default;
  152. MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph)
  153. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  154. friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; }
  155. };
  156. using TupleAddPtr = std::shared_ptr<TupleAdd>;
  157. class TupleSlice : public MetaFuncGraph {
  158. public:
  159. explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {}
  160. ~TupleSlice() override = default;
  161. MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph)
  162. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  163. friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; }
  164. };
  165. using TupleSlicePtr = std::shared_ptr<TupleSlice>;
  166. class TupleGetItemTensor : public MetaFuncGraph {
  167. public:
  168. explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {}
  169. ~TupleGetItemTensor() override = default;
  170. MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph)
  171. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  172. friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) {
  173. return lhs.name_ == rhs.name_;
  174. }
  175. };
  176. using TupleGetItemTensorPtr = std::shared_ptr<TupleGetItemTensor>;
  177. } // namespace prim
  178. } // namespace mindspore
  179. #endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_