/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef TRANSFORM_OP_ADAPTER_BASE_H_ #define TRANSFORM_OP_ADAPTER_BASE_H_ #include #include #include #include #include #include #include "transform/util.h" #include "ir/anf.h" #include "ir/primitive.h" #include "ir/value.h" #include "transform/types.h" #ifdef ENABLE_GE #ifdef OPEN_SOURCE #include "graph/types.h" #endif #endif #include "graph/operator_reg.h" #ifdef OPEN_SOURCE #include "ge/client/ge_api.h" #else #include "external/ge/ge_api.h" #endif #include "graph/tensor.h" #include "transform/all_ops.h" namespace ge { class CustomOperator : public Operator { public: CustomOperator(const string &name, const string &type) : Operator(name, type) {} ~CustomOperator() override{}; void CustomInputRegister(const string &name) { Operator::InputRegister(name); } void CustomOutputRegister(const string &name) { Operator::OutputRegister(name); } void CustomInferFuncRegister(const std::function &func) { Operator::InferFuncRegister(func); } }; } // namespace ge namespace mindspore { namespace transform { using CusOperatorPtr = std::shared_ptr; using CustomOperator = ge::CustomOperator; struct OutHandler { OperatorPtr op; std::string out; OutHandler() : op(nullptr), out("") {} OutHandler(const OperatorPtr &op, const std::string out) : op(op), out(out) {} }; struct ControlEdge { OperatorPtr src_op; OperatorPtr dest_op; }; using AttrFunc = std::function; using OutputFunc = std::function; using InputOpFunc = std::function; using InputHandleFunc = std::function; using CreateDynInputOpFunc = std::function; using DynInputOpFunc = std::function; using DynInputHandleFunc = std::function; using UpdateOutputDescFunc = std::function; using CreateDynOutputOpFunc = std::function; struct AttrDesc { std::string name; AttrFunc set_attr; }; struct InputDesc { std::string name; InputOpFunc set_op; InputHandleFunc set_handle; UpdateOutputDescFunc update_input_desc; }; struct DynInputDesc { std::string name; CreateDynInputOpFunc create_dyn_input; DynInputOpFunc set_op; DynInputHandleFunc set_handle; }; struct OutputDesc { std::string name; UpdateOutputDescFunc update_out_desc; }; struct DynOutputDesc { std::string name; CreateDynOutputOpFunc create_dyn_output; }; class BaseOpAdapter { public: virtual ~BaseOpAdapter() {} virtual OperatorPtr generate(const AnfNodePtr &anf) = 0; virtual OperatorPtr generate(const std::string &type) { return std::make_shared(type); } virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0; virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0; virtual int setInput(const OperatorPtr &op, int index, const std::shared_ptr> &handler_vec) = 0; virtual int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) = 0; virtual int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) = 0; virtual int setAttr(const OperatorPtr &op, const AnfNodePtr &node) = 0; virtual std::unordered_map GetExtraAttr() = 0; template ::value>::type> int setAttr(const OperatorPtr &op, const std::string &attrKey, const std::shared_ptr &attrValue) { return setAttr(op, attrKey, MakeValue(attrValue)); } template ::value>::type> int setAttr(const OperatorPtr &op, const std::string &attrKey, const T &attrValue) { return setAttr(op, attrKey, MakeValue(attrValue)); } virtual OutHandler getOutput(const OperatorPtr &op, int index) = 0; virtual void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, const AnfNodePtr &node) = 0; virtual const std::unordered_map &getInputMap() = 0; virtual const std::unordered_map &getInputAttrMap() = 0; virtual const std::unordered_map &getDynInputMap() = 0; virtual const std::unordered_map &getOutputMap() = 0; void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); } const std::vector &GetAttrsFromDrawGraph() const { return attrs_vec_; } void clearAttrVect() { attrs_vec_.clear(); } private: std::vector attrs_vec_; }; using OpAdapterPtr = std::shared_ptr; enum AttrType { ATTR_INT = 0, ATTR_FLOAT, ATTR_DOUBLE, ATTR_STRING, ATTR_TENSOR, ATTR_BOOL, ATTR_LIST_INT, ATTR_LIST_ANY_INT, ATTR_ENUM }; struct GeEnum {}; struct TFType {}; struct GEType {}; // declare Any type template struct AnyTraits { using type = T; }; template <> struct AnyTraits { using type = int64_t; }; using ExtraAttr = std::unordered_map; } // namespace transform } // namespace mindspore #endif // TRANSFORM_OP_ADAPTER_BASE_H_