/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_
#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_

// Declarations of composite operators (namespace mindspore::prim).
// Most classes here derive from MetaFuncGraph and produce a FuncGraph for a
// concrete argument list, via GenerateFuncGraph() or GenerateFromTypes().
//
// NOTE(review): in this copy of the header, the text between angle brackets
// appears to have been stripped.
//   - The seven standard #include lines below have no header name.
//   - Several declarations lack their template arguments (std::unordered_map,
//     std::shared_ptr, std::vector, std::set, std::map).
// As written, this will not compile. Restore the missing arguments from
// version control rather than guessing them.
#include
#include
#include
#include
#include
#include
#include

#include "operator/composite/zip_operation.h"
#include "operator/composite/list_append_operation.h"
#include "operator/composite/do_signature.h"
#include "operator/composite/unpack_call.h"
#include "pipeline/static_analysis/static_analysis.h"
#include "utils/misc.h"
#include "utils/any.h"
#include "ir/dtype.h"
#include "ir/meta_func_graph.h"

namespace mindspore {
// Namespace holding the composite operator definitions.
namespace prim {
// Short aliases for abstract-value pointer types from the `abstract` namespace.
using AbstractSlicePtr = abstract::AbstractSlicePtr;
using AbstractScalarPtr = abstract::AbstractScalarPtr;
using AbstractTensorPtr = abstract::AbstractTensorPtr;
// NOTE(review): key/value template arguments are missing in this copy.
using ElemwiseMap = std::unordered_map;
// List of pairs, passed to HyperMap::Make/FullMake as `arg_map` and to
// HyperMap::Harmonize as `args_spec_list`. The pair element types are missing
// in this copy; presumably (node, type) pairs. TODO confirm.
using ArgsPairList = std::vector>;

// A MetaFuncGraph with several registered implementations, each keyed by a
// list of argument types (TypePtrList). GenerateFromTypes() yields the
// FuncGraph for one concrete type list. Implementations may be registered as
// C++ specialization functions or as Python functions. They are presumably
// stored in fn_cache_ and fn_cache_py_ respectively; the .cpp is not visible
// here.
class MultitypeFuncGraph : public MetaFuncGraph {
 public:
  explicit MultitypeFuncGraph(const std::string &name);
  ~MultitypeFuncGraph() override = default;
  MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph)

  // Signature of a C++ specialization. It takes the argument type list by
  // value and returns a raw (non-smart) FuncGraph pointer.
  using specialize_fn = FuncGraph *(*)(TypePtrList);

  // Register an implementation for one argument type list. The overloads
  // accept a C++ specialization, or a Python function keyed either by type
  // objects or by `types_name`. `types_name` is presumably a list of type-name
  // strings; its element type is missing in this copy. TODO confirm.
  virtual void Register(const TypePtrList &types, specialize_fn s_fn);
  virtual void Register(const TypePtrList &types, const py::function &py_fn);
  virtual void Register(const std::vector &types_name, const py::function &py_fn);
  // Python-facing registration. `tuple` presumably holds the argument types
  // (or type names). TODO confirm against the implementation.
  virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn);
  // Produces the FuncGraph for a concrete argument type list (implementation
  // defined out of line).
  FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override;
  // Number of entries in the Python-function table.
  size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); }
  // Read-only reference to the Python-function table; no copy is made.
  const std::unordered_map &GetPyFunctions() const { return fn_cache_py_; }

 private:
  // NOTE(review): key/value/hasher template arguments are missing in this copy.
  std::unordered_map fn_cache_;
  std::unordered_map fn_cache_py_;
};
using MultitypeFuncGraphPtr = std::shared_ptr;

// A MetaFuncGraph that applies a leaf function (`fn_leaf_`) to its arguments.
// The FullMake overloads, the `nonleaf_` set and Harmonize() suggest that it
// recurses through container-typed arguments and applies the leaf to the
// innermost elements. TODO confirm against the implementation.
class HyperMap : public MetaFuncGraph {
 public:
  explicit HyperMap(const std::shared_ptr &fn_leaf = nullptr);
  HyperMap(const HyperMap &h);
  void Init();
  // Copies the leaf function, the broadcast flag and the non-leaf set from `h`.
  // If `h` has a leaf, this object is renamed "hyper_map[<leaf name>]".
  // NOTE(review):
  //   - When h.fn_leaf_ is null, name_ is left unchanged and may still name
  //     this object's previous leaf.
  //   - Init() is not called here.
  //   - No MetaFuncGraph base-class state is assigned.
  // Confirm this matches the out-of-line copy constructor's behavior.
  HyperMap &operator=(const HyperMap &h) {
    if (this != &h) {
      fn_leaf_ = h.fn_leaf_;
      broadcast_ = h.broadcast_;
      nonleaf_ = h.nonleaf_;
      if (fn_leaf_) {
        name_ = "hyper_map[" + fn_leaf_->name() + "]";
      }
    }
    return *this;
  }
  ~HyperMap() override = default;
  MS_DECLARE_PARENT(HyperMap, MetaFuncGraph)

  abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override;
  FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override;
  // Returns the leaf function, upcast to MetaFuncGraphPtr. It may be null,
  // since the constructor's default is nullptr.
  MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; }

 private:
  // One overload per static type of `type`. Each presumably builds the
  // mapped node for that kind of type. TODO confirm.
  // NOTE(review): the std::shared_ptr element types of the last three
  // overloads are missing in this copy.
  AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map);
  AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
                      const ArgsPairList &arg_map);
  AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
                      const ArgsPairList &arg_map);
  AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
                      const ArgsPairList &arg_map);
  // Presumably selects and invokes the matching FullMake overload. TODO confirm.
  AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map);
  // Presumably reconciles the argument list before mapping. TODO confirm.
  ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list);

  // Leaf function; may be null.
  MultitypeFuncGraphPtr fn_leaf_;
  // No in-class initializer; the out-of-line constructors must set it.
  bool broadcast_;
  // NOTE(review): element type is missing in this copy.
  std::set nonleaf_;
};
using HyperMapPtr = std::shared_ptr;

// HyperMap subclass that adds no members and only forwards its constructor.
// It presumably exists as a distinct type for the Python binding. TODO confirm.
class HyperMapPy : public HyperMap {
 public:
  explicit HyperMapPy(const std::shared_ptr &fn_leaf = nullptr) : HyperMap(fn_leaf) {}
  ~HyperMapPy() override = default;
  MS_DECLARE_PARENT(HyperMapPy, HyperMap)
};
using HyperMapPyPtr = std::shared_ptr;

// Global value defined in another translation unit. Presumably a shared
// HyperMap instance. TODO confirm.
extern ValuePtr kCompositeHyperMap;

// Generates a FuncGraph for a tuple or list argument, via the separate tuple
// and list generators. Presumably it yields every element except the first.
// TODO confirm. Two instances compare equal when their names match.
class Tail : public MetaFuncGraph {
 public:
  explicit Tail(const std::string &name) : MetaFuncGraph(name) {}
  ~Tail() override = default;
  MS_DECLARE_PARENT(Tail, MetaFuncGraph)

  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple);
  FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list);

  // Equality compares only name_.
  friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
};
using TailPtr = std::shared_ptr;

// Presumably generates the gradient graph of a make_tuple operation.
// TODO confirm. Equality compares only name_.
class MakeTupleGradient : public MetaFuncGraph {
 public:
  explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {}
  ~MakeTupleGradient() override = default;
  MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph)
  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; }
};
using MakeTupleGradientPtr = std::shared_ptr;

// Generates gradient FuncGraphs. The constructor flags all default to false
// and are presumably copied into the matching public fields below.
class GradOperation : public MetaFuncGraph {
 public:
  explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
                         bool sens_param = false);
  ~GradOperation() override = default;
  MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)

  // Presumably builds the gradient graph of `ptrNode` with respect to
  // `ptrParams` and `weights`. TODO confirm. `applyJ` defaults to false.
  // NOTE(review): the std::vector element type is missing in this copy.
  FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector &ptrParams,
                       bool applyJ = false);
  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  // NOTE(review): this accessor duplicates the public sens_param_ field below.
  bool sens_param() const { return sens_param_; }
  // Public configuration flags; other code may read them directly.
  bool get_all_;
  bool get_by_list_;
  bool sens_param_;

 private:
  void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights,
                 ValueNodePtr opsTupleItem);
};
using GradOperationPtr = std::shared_ptr;

// Standalone helper (not a MetaFuncGraph) with its own name and a cache of
// generated FuncGraphs. MakeCond and MakeNext presumably build the condition
// and step graphs of a loop over `lists`; each takes the other's graph
// (gnext_ptr / gcond_ptr). TODO confirm.
class ListMap {
 public:
  // NOTE(review): clearing a freshly constructed cache_ has no effect.
  explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); }
  ~ListMap() = default;
  void MakeCond(const std::vector &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr);
  void MakeNext(const std::vector &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr);
  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list);

 private:
  std::string name_;
  // NOTE(review): the key type is missing in this copy.
  std::map, FuncGraphPtr> cache_;
};

// Presumably generates the graph for `+` on tuples. TODO confirm.
// Equality compares only name_.
class TupleAdd : public MetaFuncGraph {
 public:
  explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {}
  ~TupleAdd() override = default;
  MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph)
  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; }
};
using TupleAddPtr = std::shared_ptr;

// Presumably generates the graph for slicing a tuple. TODO confirm.
// Equality compares only name_.
class TupleSlice : public MetaFuncGraph {
 public:
  explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {}
  ~TupleSlice() override = default;
  MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph)
  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; }
};
using TupleSlicePtr = std::shared_ptr;

// Presumably generates the graph for slicing a tensor. TODO confirm.
// Equality compares only name_.
class TensorSlice : public MetaFuncGraph {
 public:
  explicit TensorSlice(const std::string &name) : MetaFuncGraph(name) {}
  ~TensorSlice() override = default;
  MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph)
  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; }
};
using TensorSlicePtr = std::shared_ptr;

// Presumably generates the graph for indexing a tuple with a tensor index.
// TODO confirm. Equality compares only name_.
class TupleGetItemTensor : public MetaFuncGraph {
 public:
  explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {}
  ~TupleGetItemTensor() override = default;
  MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph)
  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) {
    return lhs.name_ == rhs.name_;
  }
};
using TupleGetItemTensorPtr = std::shared_ptr;
}  // namespace prim
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_H_