You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node.h 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. /**
  2. * Copyright 2021-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_
  17. #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_NODE_H_
  18. #include <memory>
  19. #include <vector>
  20. #include <set>
  21. #include <string>
  22. #include "ir/dtype/type_id.h"
  23. #include "ir/value.h"
  24. #include "ir/tensor.h"
  25. #include "utils/hash_map.h"
  26. #include "utils/shape_utils.h"
  27. #include "include/common/utils/utils.h"
  28. namespace mindspore::graphkernel::inner {
  /// Kind tag reported by Node::NodeType(); values are listed explicitly and match declaration order.
  enum class NType {
    Base = 0,       // plain Node (base-class default)
    Primitive = 1,  // not produced in this file; presumably an operator node defined elsewhere — confirm
    Parameter = 2,  // reported by ParamNode
    Value = 3,      // reported by ConstTensorNode
    Output = 4      // reported by OutputNode
  };
  36. using DFormat = std::string;
  37. using DShape = ShapeVector;
  38. using DAttrs = mindspore::HashMap<std::string, ValuePtr>;
  39. struct NodeBase {
  40. DShape shape;
  41. TypeId type;
  42. DFormat format;
  43. };
  class Node;
  // Shared-ownership handle to a graph node.
  using NodePtr = std::shared_ptr<Node>;
  // Ordered list of node handles (used, for example, for a node's inputs).
  using NodePtrList = std::vector<std::shared_ptr<Node>>;
  47. class Node : public NodeBase, public std::enable_shared_from_this<Node> {
  48. public:
  49. explicit Node(const NodeBase &baseinfo) : NodeBase(baseinfo) {}
  50. virtual ~Node() { ClearInputs(); } // remove this node from the previous nodes' user.
  51. virtual NType NodeType() { return NType::Base; }
  52. virtual std::string ToString() const;
  53. void SetBaseInfo(NodeBase baseinfo);
  54. void AddInput(const NodePtr &new_input);
  55. void SetInput(size_t i, const NodePtr &new_input);
  56. void SetInputs(const NodePtrList &inputs);
  57. void ClearInputs();
  58. void ReplaceWith(const NodePtr &other_node);
  59. void SetAttrs(const DAttrs &attrs) { attrs_ = attrs; }
  60. void SetAttr(const std::string &key, const ValuePtr &value) { attrs_[key] = value; }
  61. void SetDebugName(const std::string &debug_name) { debug_name_ = debug_name; }
  62. template <typename T>
  63. std::shared_ptr<T> As() {
  64. return std::static_pointer_cast<T>(shared_from_this());
  65. }
  66. const std::string &debug_name() const { return debug_name_; }
  67. const DAttrs &attrs() const { return attrs_; }
  68. const NodePtr &input(size_t i) const { return inputs_[i]; }
  69. const NodePtrList &inputs() const { return inputs_; }
  70. const mindspore::HashMap<Node *, std::set<size_t>> &users() const { return users_; }
  71. protected:
  72. mutable std::string debug_name_; // only used in Dump function
  73. DAttrs attrs_;
  74. NodePtrList inputs_;
  75. mindspore::HashMap<Node *, std::set<size_t>> users_; // {user_node: {input edge index set}}
  76. private:
  77. // the nodes' users are only maintained by AddInput/SetInput.
  78. void AddUser(Node *user, size_t index) { users_[user].insert(index); }
  79. void RemoveUser(Node *const user, size_t index);
  80. };
  81. class ConstTensorNode : public Node {
  82. public:
  83. explicit ConstTensorNode(const tensor::TensorPtr &data)
  84. : Node({data->shape(), data->data_type(), kOpFormat_DEFAULT}), data_(data) {}
  85. ~ConstTensorNode() = default;
  86. NType NodeType() override { return NType::Value; }
  87. std::string ToString() const override { return data_->data().ToString(this->type, this->shape, false); }
  88. const tensor::TensorPtr data() const { return data_; }
  89. protected:
  90. tensor::TensorPtr data_;
  91. };
  92. class ParamNode : public Node {
  93. public:
  94. explicit ParamNode(const NodeBase &baseinfo) : Node(baseinfo) {}
  95. ~ParamNode() = default;
  96. NType NodeType() override { return NType::Parameter; }
  97. };
  98. // the OutputNode's inputs are the real outputs of graph, like the `make_tuple` in FuncGraph.
  99. class OutputNode : public Node {
  100. public:
  101. OutputNode() : Node({{1}, TypeId::kNumberTypeBegin, kOpFormat_DEFAULT}) {}
  102. ~OutputNode() = default;
  103. NType NodeType() override { return NType::Output; }
  104. };
  105. } // namespace mindspore::graphkernel::inner
  106. #endif