You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

composite.h 9.0 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_
  19. #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_
  20. #include <vector>
  21. #include <string>
  22. #include <utility>
  23. #include <map>
  24. #include <set>
  25. #include <memory>
  26. #include "utils/hash_map.h"
  27. #include "frontend/operator/composite/zip_operation.h"
  28. #include "frontend/operator/composite/list_append_operation.h"
  29. #include "frontend/operator/composite/do_signature.h"
  30. #include "frontend/operator/composite/unpack_call.h"
  31. #include "frontend/operator/composite/multitype_funcgraph.h"
  32. #include "pipeline/jit/static_analysis/static_analysis.h"
  33. #include "utils/misc.h"
  34. #include "utils/any.h"
  35. #include "ir/dtype.h"
  36. #include "ir/meta_func_graph.h"
  37. namespace mindspore {
  38. // namespace to support composite operators definition
  39. namespace prim {
  40. using AbstractSlicePtr = abstract::AbstractSlicePtr;
  41. using AbstractScalarPtr = abstract::AbstractScalarPtr;
  42. using AbstractTensorPtr = abstract::AbstractTensorPtr;
  43. using ElemwiseMap = mindspore::HashMap<std::string, PrimitivePtr>;
  44. using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
  45. class HyperMap : public MetaFuncGraph {
  46. public:
  47. explicit HyperMap(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr);
  48. HyperMap(const HyperMap &h);
  49. void Init();
  50. HyperMap &operator=(const HyperMap &h) {
  51. if (this != &h) {
  52. fn_leaf_ = h.fn_leaf_;
  53. reverse_ = h.reverse_;
  54. broadcast_ = h.broadcast_;
  55. nonleaf_ = h.nonleaf_;
  56. if (fn_leaf_) {
  57. name_ = "hyper_map[" + fn_leaf_->name() + "]";
  58. }
  59. }
  60. return *this;
  61. }
  62. ~HyperMap() override = default;
  63. MS_DECLARE_PARENT(HyperMap, MetaFuncGraph)
  64. abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override;
  65. FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override;
  66. MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; }
  67. private:
  68. AnfNodePtr FullMake(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map);
  69. AnfNodePtr FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  70. const ArgsPairList &arg_map);
  71. AnfNodePtr FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  72. const ArgsPairList &arg_map);
  73. AnfNodePtr FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
  74. const ArgsPairList &arg_map);
  75. AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map);
  76. ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list);
  77. MultitypeFuncGraphPtr fn_leaf_;
  78. bool reverse_;
  79. bool broadcast_;
  80. std::set<TypeId> nonleaf_;
  81. };
  82. using HyperMapPtr = std::shared_ptr<HyperMap>;
  83. class HyperMapPy : public HyperMap {
  84. public:
  85. explicit HyperMapPy(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
  86. : HyperMap(reverse, fn_leaf) {}
  87. ~HyperMapPy() override = default;
  88. MS_DECLARE_PARENT(HyperMapPy, HyperMap)
  89. };
  90. using HyperMapPyPtr = std::shared_ptr<HyperMapPy>;
  91. extern ValuePtr kCompositeHyperMap;
  92. enum TailType { kGradAll, kGradFirst, kGradByPosition, kNotGrad };
  93. class Tail : public MetaFuncGraph {
  94. public:
  95. explicit Tail(const std::string &name, TailType tail_type = kNotGrad)
  96. : MetaFuncGraph(name), tail_type_(tail_type), enable_tuple_grad_(false) {}
  97. ~Tail() override = default;
  98. MS_DECLARE_PARENT(Tail, MetaFuncGraph)
  99. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  100. FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue,
  101. const abstract::AbstractSequeuePtr &pos = nullptr) const;
  102. friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
  103. void set_enable_tuple_grad(bool enable_tuple_grad) { enable_tuple_grad_ = enable_tuple_grad; }
  104. private:
  105. TailType tail_type_;
  106. bool enable_tuple_grad_;
  107. };
  108. using TailPtr = std::shared_ptr<Tail>;
  109. class MakeTupleGradient : public MetaFuncGraph {
  110. public:
  111. explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {}
  112. ~MakeTupleGradient() override = default;
  113. MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph)
  114. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  115. friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; }
  116. };
  117. using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>;
  118. class MakeListGradient : public MetaFuncGraph {
  119. public:
  120. explicit MakeListGradient(const std::string &name) : MetaFuncGraph(name) {}
  121. ~MakeListGradient() override = default;
  122. MS_DECLARE_PARENT(MakeListGradient, MetaFuncGraph)
  123. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  124. friend bool operator==(const MakeListGradient &lhs, const MakeListGradient &rhs) { return lhs.name_ == rhs.name_; }
  125. };
  126. using MakeListGradientPtr = std::shared_ptr<MakeListGradient>;
  127. class GradOperation : public MetaFuncGraph {
  128. public:
  129. explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
  130. bool sens_param = false, bool get_by_position = false);
  131. ~GradOperation() override = default;
  132. MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
  133. FuncGraphPtr GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights, const AnfNodePtr &position,
  134. const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
  135. const std::vector<AnfNodePtr> &weight_args = {});
  136. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  137. bool sens_param() const { return sens_param_; }
  138. bool get_all_;
  139. bool get_by_list_;
  140. bool sens_param_;
  141. bool get_by_position_;
  142. private:
  143. void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
  144. const AnfNodePtr &weights, const AnfNodePtr &position, bool enable_tuple_grad);
  145. };
  146. using GradOperationPtr = std::shared_ptr<GradOperation>;
  147. class ListMap {
  148. public:
  149. explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); }
  150. ~ListMap() = default;
  151. void MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr);
  152. void MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr);
  153. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list);
  154. private:
  155. std::string name_;
  156. std::map<std::vector<AnyPtr>, FuncGraphPtr> cache_;
  157. };
  158. class TupleAdd : public MetaFuncGraph {
  159. public:
  160. explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {}
  161. ~TupleAdd() override = default;
  162. MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph)
  163. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  164. friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; }
  165. };
  166. using TupleAddPtr = std::shared_ptr<TupleAdd>;
  167. class TupleSlice : public MetaFuncGraph {
  168. public:
  169. explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {}
  170. ~TupleSlice() override = default;
  171. MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph)
  172. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  173. friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; }
  174. };
  175. using TupleSlicePtr = std::shared_ptr<TupleSlice>;
  176. class TupleGetItemTensor : public MetaFuncGraph {
  177. public:
  178. explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {}
  179. ~TupleGetItemTensor() override = default;
  180. MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph)
  181. FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  182. friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) {
  183. return lhs.name_ == rhs.name_;
  184. }
  185. };
  186. using TupleGetItemTensorPtr = std::shared_ptr<TupleGetItemTensor>;
  187. } // namespace prim
  188. } // namespace mindspore
  189. #endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_