You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

value.h 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_IR_VALUE_H_
  17. #define MINDSPORE_CCSRC_IR_VALUE_H_
  #include <algorithm>
  #include <initializer_list>
  #include <iterator>
  #include <memory>
  #include <sstream>
  #include <string>
  #include <type_traits>
  #include <typeinfo>
  #include <utility>
  #include <vector>

  #include "ir/base.h"
  #include "ir/anf.h"
  #include "ir/dtype.h"
  #include "ir/scalar.h"
  #include "utils/hashing.h"
  #include "common/utils.h"
  31. namespace mindspore {
  32. class ValueSequeue : public Value {
  33. public:
  34. explicit ValueSequeue(const ValuePtrList &elements) : elements_(elements) {
  35. TypePtrList t_list;
  36. (void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr &ele) {
  37. MS_EXCEPTION_IF_NULL(ele);
  38. return ele->type();
  39. });
  40. TypePtr t = std::make_shared<Tuple>(t_list);
  41. type_ = t;
  42. }
  43. ValueSequeue(const std::initializer_list<ValuePtr> &elements) : elements_(elements.begin(), elements.end()) {
  44. TypePtrList t_list;
  45. (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(t_list),
  46. [](const ValuePtr &ele) { return ele->type(); });
  47. TypePtr t = std::make_shared<Tuple>(t_list);
  48. type_ = t;
  49. }
  50. ~ValueSequeue() override = default;
  51. MS_DECLARE_PARENT(ValueSequeue, Value)
  52. std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(elements_.size())); }
  53. std::size_t size() const { return elements_.size(); }
  54. bool erase(size_t idx);
  55. const ValuePtr operator[](const std::size_t &dim) const;
  56. const ValuePtrList &value() const { return elements_; }
  57. bool operator==(const Value &other) const override;
  58. bool operator==(const ValueSequeue &other) const;
  59. std::string ToString() const override;
  60. std::string DumpText() const override;
  61. protected:
  62. ValuePtrList elements_;
  63. };
  64. using ValueSequeuePtr = std::shared_ptr<ValueSequeue>;
  65. class ValueTuple : public ValueSequeue {
  66. public:
  67. explicit ValueTuple(const std::vector<ValuePtr> &elements) : ValueSequeue(elements) {}
  68. ValueTuple(const std::initializer_list<ValuePtr> &elements) : ValueSequeue(elements) {}
  69. ~ValueTuple() override = default;
  70. MS_DECLARE_PARENT(ValueTuple, ValueSequeue)
  71. abstract::AbstractBasePtr ToAbstract() override;
  72. std::string DumpText() const override { return "(" + ValueSequeue::DumpText() + ")"; }
  73. std::string ToString() const override { return "(" + ValueSequeue::ToString() + ")"; }
  74. };
  75. using ValueTuplePtr = std::shared_ptr<ValueTuple>;
  76. class ValueList : public ValueSequeue {
  77. public:
  78. explicit ValueList(const std::vector<ValuePtr> &elements) : ValueSequeue(elements) {}
  79. ValueList(const std::initializer_list<ValuePtr> &elements) : ValueSequeue(elements) {}
  80. ~ValueList() override = default;
  81. MS_DECLARE_PARENT(ValueList, ValueSequeue)
  82. abstract::AbstractBasePtr ToAbstract() override;
  83. std::string DumpText() const override { return "[" + ValueSequeue::DumpText() + "]"; }
  84. std::string ToString() const override { return "[" + ValueSequeue::ToString() + "]"; }
  85. };
  86. using ValueListPtr = std::shared_ptr<ValueList>;
  87. inline ValuePtr MakeValue(const std::vector<ValuePtr> &v) { return std::make_shared<ValueTuple>(v); }
  88. inline ValuePtr MakeValue(std::initializer_list<ValuePtr> v) { return std::make_shared<ValueTuple>(v); }
  /// Type trait: is_vector<T>::value is true exactly when T is a std::vector
  /// specialization (with any allocator), false otherwise.
  template <typename T>
  struct is_vector : std::false_type {};
  template <typename T, typename A>
  struct is_vector<std::vector<T, A>> : std::true_type {};
  93. template <typename T, typename U = typename std::enable_if<is_vector<T>::value, typename T::value_type>::type>
  94. ValuePtr MakeValue(const T &vec) {
  95. std::vector<ValuePtr> list;
  96. (void)std::transform(vec.begin(), vec.end(), std::back_inserter(list), [](U ele) { return MakeValue(ele); });
  97. return std::make_shared<ValueTuple>(list);
  98. }
  99. class ValueSlice : public Value {
  100. public:
  101. ValueSlice(const ValuePtr &start, const ValuePtr &stop, const ValuePtr &step)
  102. : start_(start), stop_(stop), step_(step) {}
  103. ~ValueSlice() override = default;
  104. MS_DECLARE_PARENT(ValueSlice, Value)
  105. std::size_t hash() const override;
  106. bool operator==(const Value &other) const override;
  107. bool operator==(const ValueSlice &other) const;
  108. std::string ToString() const override;
  109. abstract::AbstractBasePtr ToAbstract() override;
  110. std::string DumpText() const override { return ToString(); }
  111. ValuePtr start() const { return start_; }
  112. ValuePtr stop() const { return stop_; }
  113. ValuePtr step() const { return step_; }
  114. private:
  115. ValuePtr start_;
  116. ValuePtr stop_;
  117. ValuePtr step_;
  118. };
  119. using ValueSlicePtr = std::shared_ptr<ValueSlice>;
  120. class KeywordArg : public Value {
  121. public:
  122. KeywordArg(const std::string &key, const ValuePtr &value) : key_(key), value_(value) {}
  123. ~KeywordArg() override = default;
  124. MS_DECLARE_PARENT(KeywordArg, Value)
  125. std::size_t hash() const override;
  126. ValuePtr get_value() const { return value_; }
  127. bool operator==(const Value &other) const override;
  128. bool operator==(const KeywordArg &other) const;
  129. std::string ToString() const override;
  130. abstract::AbstractBasePtr ToAbstract() override;
  131. std::string DumpText() const override { return ToString(); }
  132. private:
  133. std::string key_;
  134. ValuePtr value_;
  135. };
  136. using KeywordArgPtr = std::shared_ptr<KeywordArg>;
  137. class ValueDictionary : public Value {
  138. public:
  139. explicit ValueDictionary(const std::vector<std::pair<std::string, ValuePtr>> &key_values) : key_values_(key_values) {}
  140. ~ValueDictionary() override = default;
  141. MS_DECLARE_PARENT(ValueDictionary, Value)
  142. std::size_t hash() const override { return hash_combine(tid(), std::hash<std::size_t>{}(key_values_.size())); }
  143. std::size_t size() const { return key_values_.size(); }
  144. const ValuePtr operator[](const std::string &key) const;
  145. const std::vector<std::pair<std::string, ValuePtr>> &value() const { return key_values_; }
  146. bool operator==(const Value &other) const override;
  147. bool operator==(const ValueDictionary &other) const;
  148. std::string ToString() const override {
  149. std::ostringstream buffer;
  150. std::vector<std::string> keys;
  151. std::vector<ValuePtr> values;
  152. for (const auto &kv : key_values_) {
  153. keys.push_back(kv.first);
  154. values.push_back(kv.second);
  155. }
  156. buffer << "(Dict: "
  157. << " keys:(";
  158. for (const auto &key : keys) {
  159. buffer << key << ", ";
  160. }
  161. buffer << ") values:(";
  162. for (const auto &value : values) {
  163. MS_EXCEPTION_IF_NULL(value);
  164. buffer << value->DumpText() << ", ";
  165. }
  166. buffer << ")";
  167. return buffer.str();
  168. }
  169. abstract::AbstractBasePtr ToAbstract() override;
  170. std::string DumpText() const override { return ToString(); }
  171. private:
  172. std::vector<std::pair<std::string, ValuePtr>> key_values_;
  173. };
  174. using ValueDictionaryPtr = std::shared_ptr<ValueDictionary>;
  175. class StringImm : public Value {
  176. public:
  177. explicit StringImm(const std::string &str) : Value(kString), str_(str), hash_(std::hash<std::string>{}(str_)) {}
  178. ~StringImm() override = default;
  179. MS_DECLARE_PARENT(StringImm, Value)
  180. std::size_t hash() const override { return hash_; }
  181. const std::string &value() const { return str_; }
  182. bool operator==(const Value &other) const override;
  183. bool operator==(const StringImm &other) const;
  184. abstract::AbstractBasePtr ToAbstract() override;
  185. std::string ToString() const override { return str_; }
  186. std::string DumpText() const override {
  187. std::ostringstream oss;
  188. oss << "\"" << str_ << "\"";
  189. return oss.str();
  190. }
  191. private:
  192. std::string str_;
  193. std::size_t hash_ = 0;
  194. };
  195. using StringImmPtr = std::shared_ptr<StringImm>;
  196. IMM_TRAITS(StringImmPtr, std::string)
  197. IMM_TRAITS(StringImmPtr, const char *)
  198. class RefKey : public Value {
  199. public:
  200. explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash<std::string>{}(tag)) {}
  201. ~RefKey() override = default;
  202. MS_DECLARE_PARENT(RefKey, Value)
  203. std::size_t hash() const override { return hash_; }
  204. const std::string &tag() const { return tag_; }
  205. bool operator==(const Value &other) const override;
  206. bool operator==(const RefKey &other) const;
  207. abstract::AbstractBasePtr ToAbstract() override;
  208. std::string ToString() const override { return "RefKey[" + tag_ + "]"; }
  209. std::string DumpText() const override {
  210. std::ostringstream oss;
  211. oss << "RefKey[\"" << tag_ << "\"]";
  212. return oss.str();
  213. }
  214. private:
  215. std::string tag_;
  216. std::size_t hash_ = 0;
  217. };
  218. using RefKeyPtr = std::shared_ptr<RefKey>;
  219. class AnyValue : public Value {
  220. public:
  221. AnyValue() = default;
  222. ~AnyValue() override = default;
  223. MS_DECLARE_PARENT(AnyValue, Value)
  224. std::size_t hash() const override { return tid(); }
  225. bool operator==(const Value &other) const override;
  226. abstract::AbstractBasePtr ToAbstract() override;
  227. };
  228. extern const ValuePtr kAnyValue;
  229. template <>
  230. inline const char *GetValue(const ValuePtr &value) {
  231. if (value == nullptr) {
  232. MS_LOG(EXCEPTION) << "Value is nullptr";
  233. }
  234. auto imm = value->cast<StringImmPtr>();
  235. if (imm == nullptr) {
  236. MS_LOG(EXCEPTION) << "GetValue:" << value->ToString() << ", Type:" << value->type_name();
  237. }
  238. return common::SafeCStr(imm->value());
  239. }
  240. template <typename T, typename S = typename std::decay<T>::type,
  241. typename U = typename std::enable_if<is_vector<S>::value, typename S::value_type>::type>
  242. std::vector<U> GetValue(const ValuePtr &value) {
  243. if (value == nullptr) {
  244. MS_LOG(EXCEPTION) << "Value is nullptr";
  245. }
  246. if (!value->isa<ValueSequeue>()) {
  247. MS_LOG(EXCEPTION) << "Error GetValue for value: " << value->ToString() << ", type: vector<" << typeid(U).name()
  248. << ">";
  249. }
  250. std::vector<U> rets;
  251. const std::vector<ValuePtr> &vals = value->cast<ValueSequeuePtr>()->value();
  252. (void)std::transform(vals.begin(), vals.end(), std::back_inserter(rets),
  253. [](const ValuePtr &v) { return GetValue<U>(v); });
  254. return rets;
  255. }
  256. inline ValueNodePtr NewValueNode(const ValuePtr &t) { return std::make_shared<ValueNode>(t); }
  257. template <typename T, typename _ = typename std::enable_if<!std::is_base_of<Value, T>::value>::type>
  258. inline ValueNodePtr NewValueNode(const std::shared_ptr<T> &x) {
  259. return NewValueNode(MakeValue(x));
  260. }
  261. template <typename T, typename _ = typename std::enable_if<!is_shared_ptr<T>::value>::type>
  262. inline ValueNodePtr NewValueNode(const T &x) {
  263. return NewValueNode(MakeValue(x));
  264. }
  265. } // namespace mindspore
  266. #endif // MINDSPORE_CCSRC_IR_VALUE_H_