/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef PREDICT_SRC_NODE_H_ #define PREDICT_SRC_NODE_H_ #include #include #include #include "include/session.h" #include "src/op.h" namespace mindspore { namespace predict { using NODE_ID = std::string; class Node { public: Node() = default; explicit Node(const NodeDef *nodeDef); virtual ~Node(); NODE_ID ID(); std::string Type(); void SetTensors(const NodeDef &nodeDef, const std::vector &allTensors); void SetDepends(const std::unordered_set &deps); std::unordered_set GetDepends(); void AddInEdge(Node *node); void AddOutEdge(Node *node); std::unordered_set &GetAllOutEdges(); std::unordered_set &GetAllInEdges(); std::vector &GetOutputTensors(); std::vector &GetInputTensors(); int InitOp(const OpDef &opDef, const Context &ctx); int Run(const Context &ctx); int MallocOutput(const Context &ctx); void FreeInput(); protected: friend class GraphExecution; NODE_ID id; std::string type; OpBase *op{}; std::vector inputs; std::vector outputs; std::unordered_set depends; std::unordered_set inEdges; std::unordered_set outEdges; }; } // namespace predict } // namespace mindspore #endif // PREDICT_SRC_NODE_H_