/** * Copyright 2021-2022 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_ #include #include #include #include #include "ir/dtype/type_id.h" #include "ir/value.h" #include "ir/tensor.h" #include "utils/hash_map.h" #include "utils/shape_utils.h" #include "include/common/utils/utils.h" namespace mindspore::graphkernel::inner { enum class NType { Base, Primitive, Parameter, Value, Output, }; using DFormat = std::string; using DShape = ShapeVector; using DAttrs = mindspore::HashMap; struct NodeBase { DShape shape; TypeId type; DFormat format; }; class Node; using NodePtr = std::shared_ptr; using NodePtrList = std::vector; class Node : public NodeBase, public std::enable_shared_from_this { public: explicit Node(const NodeBase &baseinfo) : NodeBase(baseinfo) {} virtual ~Node() { ClearInputs(); } // remove this node from the previous nodes' user. virtual NType NodeType() { return NType::Base; } virtual std::string ToString() const; void SetBaseInfo(NodeBase baseinfo); void AddInput(const NodePtr &new_input); void SetInput(size_t i, const NodePtr &new_input); void SetInputs(const NodePtrList &inputs); void ClearInputs(); void ReplaceWith(const NodePtr &other_node); void SetAttrs(const DAttrs &attrs) { attrs_ = attrs; } void SetAttr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; } void SetDebugName(const std::string &debug_name) { debug_name_ = debug_name; } template std::shared_ptr As() { return std::static_pointer_cast(shared_from_this()); } const std::string &debug_name() const { return debug_name_; } const DAttrs &attrs() const { return attrs_; } const NodePtr &input(size_t i) const { return inputs_[i]; } const NodePtrList &inputs() const { return inputs_; } const mindspore::HashMap> &users() const { return users_; } protected: mutable std::string debug_name_; // only used in Dump function DAttrs attrs_; NodePtrList inputs_; mindspore::HashMap> users_; // {user_node: {input edge index set}} private: // the nodes' users are only maintained by AddInput/SetInput. void AddUser(Node *user, size_t index) { users_[user].insert(index); } void RemoveUser(Node *const user, size_t index); }; class ConstTensorNode : public Node { public: explicit ConstTensorNode(const tensor::TensorPtr &data) : Node({data->shape(), data->data_type(), kOpFormat_DEFAULT}), data_(data) {} ~ConstTensorNode() = default; NType NodeType() override { return NType::Value; } std::string ToString() const override { return data_->data().ToString(this->type, this->shape, false); } const tensor::TensorPtr data() const { return data_; } protected: tensor::TensorPtr data_; }; class ParamNode : public Node { public: explicit ParamNode(const NodeBase &baseinfo) : Node(baseinfo) {} ~ParamNode() = default; NType NodeType() override { return NType::Parameter; } }; // the OutputNode's inputs are the real outputs of graph, like the `make_tuple` in FuncGraph. class OutputNode : public Node { public: OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}) {} ~OutputNode() = default; NType NodeType() override { return NType::Output; } }; } // namespace mindspore::graphkernel::inner #endif