You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

symbolic.h 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CCSRC_UTILS_SYMBOLIC_H_
  19. #define MINDSPORE_CCSRC_UTILS_SYMBOLIC_H_
  20. #include <unordered_map>
  21. #include <memory>
  22. #include <algorithm>
  23. #include <utility>
  24. #include <string>
  25. #include "ir/anf.h"
  26. #include "pipeline/static_analysis/abstract_value.h"
  27. #include "utils/any.h"
  28. namespace mindspore {
  29. class SymbolicKeyInstance : public Value {
  30. public:
  31. SymbolicKeyInstance(const AnfNodePtr &node, const abstract::AbstractBasePtr &abstract)
  32. : node_(node), abstract_(abstract) {}
  33. ~SymbolicKeyInstance() override = default;
  34. MS_DECLARE_PARENT(SymbolicKeyInstance, Value);
  35. AnfNodePtr node() const { return node_; }
  36. abstract::AbstractBasePtr abstract() const { return abstract_; }
  37. bool operator==(const SymbolicKeyInstance &other) const {
  38. return (*node_ == *other.node_) && (*abstract_ == *other.abstract_);
  39. }
  40. std::size_t hash() const override { return std::hash<AnfNodePtr>{}(node_); }
  41. friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<SymbolicKeyInstance> &inst) {
  42. if (inst == nullptr) {
  43. os << "[Key]["
  44. << "Invalid symbolic key instance"
  45. << "]";
  46. } else {
  47. os << "[Key][" << inst->node_->type_name() << "]" << inst->node_->ToString();
  48. }
  49. return os;
  50. }
  51. std::string ToString() const override {
  52. return node_ == nullptr ? "Invalid node" : "[Key][" + node_->type_name() + "]" + node_->ToString();
  53. }
  54. bool operator==(const Value &other) const override {
  55. if (other.isa<SymbolicKeyInstance>()) {
  56. auto other_ = static_cast<const SymbolicKeyInstance &>(other);
  57. return *this == other_;
  58. } else {
  59. return false;
  60. }
  61. }
  62. abstract::AbstractBasePtr ToAbstract() override {
  63. return std::make_shared<abstract::AbstractScalar>(shared_from_base<SymbolicKeyInstance>(),
  64. std::make_shared<SymbolicKeyType>());
  65. }
  66. private:
  67. AnfNodePtr node_;
  68. abstract::AbstractBasePtr abstract_;
  69. };
  70. using SymbolicKeyInstancePtr = std::shared_ptr<SymbolicKeyInstance>;
  71. struct SymbolicKeyInstanceHash {
  72. std::size_t operator()(const SymbolicKeyInstancePtr s) const {
  73. if (s == nullptr) {
  74. return 0;
  75. }
  76. return s->abstract()->hash();
  77. }
  78. };
  79. struct SymbolicKeyInstanceEqual {
  80. bool operator()(const SymbolicKeyInstancePtr lhs, const SymbolicKeyInstancePtr rhs) const {
  81. if (lhs == nullptr || rhs == nullptr) {
  82. return false;
  83. }
  84. MS_EXCEPTION_IF_NULL(lhs->node());
  85. MS_EXCEPTION_IF_NULL(rhs->node());
  86. MS_EXCEPTION_IF_NULL(lhs->abstract());
  87. MS_EXCEPTION_IF_NULL(rhs->abstract());
  88. return (*lhs->node() == *rhs->node()) && (*lhs->abstract() == *rhs->abstract());
  89. }
  90. };
  91. using EnvInstanceContentsMap =
  92. std::unordered_map<SymbolicKeyInstancePtr, Any, SymbolicKeyInstanceHash, SymbolicKeyInstanceEqual>;
  93. // Environment mapping keys to values.
  94. // Keys are SymbolicKeyInstances, which represent nodes in the graph along
  95. // with inferred properties.
  96. class EnvInstance : public Value {
  97. public:
  98. friend std::ostream &operator<<(std::ostream &out, const std::shared_ptr<EnvInstance> &env);
  99. explicit EnvInstance(const EnvInstanceContentsMap &contents = {}) : contents_(contents) {}
  100. ~EnvInstance() override = default;
  101. MS_DECLARE_PARENT(EnvInstance, Value);
  102. abstract::AbstractBasePtr ToAbstract() override {
  103. return std::make_shared<abstract::AbstractScalar>(shared_from_base<EnvInstance>(), std::make_shared<EnvType>());
  104. }
  105. bool operator==(const EnvInstance &other) const;
  106. bool operator==(const Value &other) const override;
  107. EnvInstance(const EnvInstance &v) : Value(v), contents_(v.contents_) {}
  108. EnvInstance(EnvInstance &&v) = default;
  109. EnvInstance &operator=(EnvInstance &&src) noexcept {
  110. if (&src != this) {
  111. contents_ = src.contents_;
  112. }
  113. return *this;
  114. };
  115. // Get the sensitivity list for the given key
  116. const Any &Get(const SymbolicKeyInstancePtr &key, const Any &def) const {
  117. auto iterator = contents_.find(key);
  118. if (iterator != contents_.end()) {
  119. return iterator->second;
  120. }
  121. return def;
  122. }
  123. // Set a value for the given key.
  124. EnvInstance Set(const SymbolicKeyInstancePtr &key, const Any &value) const {
  125. EnvInstance rval(contents_);
  126. rval.contents_[key] = value;
  127. return rval;
  128. }
  129. // Add two EnvInstances.
  130. EnvInstance Add(const EnvInstance &other) const {
  131. EnvInstance rval(contents_);
  132. for (auto iter_other : other.contents_) {
  133. auto item_self = contents_.find(iter_other.first);
  134. if (item_self != contents_.end()) {
  135. MS_LOG(DEBUG) << "Need to use add";
  136. } else {
  137. rval.contents_[iter_other.first] = iter_other.second;
  138. }
  139. }
  140. return rval;
  141. }
  142. size_t Len() const { return contents_.size(); }
  143. std::size_t hash() const override {
  144. // deterministic characteristic of member variables.
  145. return Len();
  146. }
  147. const bool parse_info_ = true;
  148. private:
  149. EnvInstanceContentsMap contents_;
  150. };
  151. using EnvInstancePtr = std::shared_ptr<EnvInstance>;
  152. extern std::shared_ptr<EnvInstance> newenv;
  153. } // namespace mindspore
  154. #endif // MINDSPORE_CCSRC_UTILS_SYMBOLIC_H_