You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dshape.h 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef MINDSPORE_CORE_ABSTRACT_DSHAPE_H_
  19. #define MINDSPORE_CORE_ABSTRACT_DSHAPE_H_
  20. #include <vector>
  21. #include <string>
  22. #include <sstream>
  23. #include <unordered_map>
  24. #include <typeindex>
  25. #include <memory>
  26. #include <algorithm>
  27. #include "utils/log_adapter.h"
  28. #include "base/base.h"
  29. #include "utils/shape_utils.h"
  30. namespace mindspore {
  31. namespace abstract {
  32. class BaseShape;
  33. using BaseShapePtr = std::shared_ptr<BaseShape>;
  34. using BaseShapePtrList = std::vector<BaseShapePtr>;
  35. class BaseShape : public Base {
  36. public:
  37. BaseShape() = default;
  38. ~BaseShape() override = default;
  39. MS_DECLARE_PARENT(BaseShape, Base)
  40. virtual bool operator==(const BaseShape &other) const;
  41. bool operator!=(const BaseShape &other) const;
  42. std::size_t hash() const override { return tid(); }
  43. // return a deep copy
  44. virtual BaseShapePtr Clone() const = 0;
  45. virtual void Broaden() {}
  46. };
  47. class NoShape : public BaseShape {
  48. public:
  49. MS_DECLARE_PARENT(NoShape, BaseShape)
  50. BaseShapePtr Clone() const override { return std::make_shared<NoShape>(); }
  51. std::string ToString() const override { return type_name(); }
  52. };
  53. extern const std::shared_ptr<NoShape> kNoShape;
  54. class Shape : public BaseShape {
  55. public:
  56. static const int SHP_ANY = -1;
  57. Shape() : shape_() {}
  58. Shape(const std::initializer_list<int> &list) : shape_(list) {}
  59. Shape(const std::initializer_list<int64_t> &list) {
  60. std::vector<int64_t> list_in(list);
  61. (void)std::transform(list_in.begin(), list_in.end(), std::back_inserter(shape_),
  62. [](const int64_t &value) { return static_cast<int>(value); });
  63. }
  64. explicit Shape(const ShapeVector &list) : shape_(list) {}
  65. explicit Shape(const std::vector<int64_t> &list) {
  66. (void)std::transform(list.begin(), list.end(), std::back_inserter(shape_),
  67. [](const int64_t &value) { return static_cast<int>(value); });
  68. }
  69. Shape(const ShapeVector &list, const ShapeVector &min_shape, const ShapeVector &max_shape)
  70. : shape_(list), min_shape_(min_shape), max_shape_(max_shape) {}
  71. ~Shape() override = default;
  72. MS_DECLARE_PARENT(Shape, BaseShape)
  73. std::string ToString() const override;
  74. std::string DumpText() const override;
  75. bool operator==(const BaseShape &other) const override;
  76. BaseShapePtr Clone() const override { return std::make_shared<Shape>(shape_, min_shape_, max_shape_); }
  77. void Broaden() override;
  78. ShapeVector &shape() { return shape_; }
  79. ShapeVector &min_shape() { return min_shape_; }
  80. ShapeVector &max_shape() { return max_shape_; }
  81. ShapeVector shape_; // use SHP_ANY to implement the any shape in python
  82. ShapeVector min_shape_; // record mininum length for each dynamic dimention
  83. ShapeVector max_shape_; // record maximum length for each dynamic dimention
  84. };
  85. using ShapePtr = std::shared_ptr<Shape>;
  86. using ShapePtrList = std::vector<ShapePtr>;
  87. class SequeueShape : public BaseShape {
  88. public:
  89. SequeueShape() : p_shapes_() {}
  90. explicit SequeueShape(const BaseShapePtrList &shapes) : p_shapes_(shapes) {}
  91. ~SequeueShape() override = default;
  92. MS_DECLARE_PARENT(SequeueShape, BaseShape)
  93. std::string ToString() const override;
  94. BaseShapePtrList ElementsClone() const;
  95. template <typename T>
  96. bool SequeueEqual(const BaseShape &other) const;
  97. const BaseShapePtrList &shape() const { return p_shapes_; }
  98. size_t size() const { return p_shapes_.size(); }
  99. const BaseShapePtr operator[](std::size_t dim) const { return p_shapes_[dim]; }
  100. protected:
  101. BaseShapePtrList p_shapes_; // shape list of each elements
  102. };
  103. using SequeueShapePtr = std::shared_ptr<SequeueShape>;
  104. class TupleShape : public SequeueShape {
  105. public:
  106. TupleShape() : SequeueShape() {}
  107. explicit TupleShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {}
  108. ~TupleShape() override = default;
  109. MS_DECLARE_PARENT(TupleShape, SequeueShape)
  110. std::string ToString() const override { return type_name() + "(" + SequeueShape::ToString() + ")"; }
  111. BaseShapePtr Clone() const override { return std::make_shared<TupleShape>(ElementsClone()); }
  112. bool operator==(const BaseShape &other) const override { return SequeueEqual<TupleShape>(other); }
  113. };
  114. using TupleShapePtr = std::shared_ptr<TupleShape>;
  115. class ListShape : public SequeueShape {
  116. public:
  117. ListShape() : SequeueShape() {}
  118. explicit ListShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {}
  119. ~ListShape() override = default;
  120. MS_DECLARE_PARENT(ListShape, SequeueShape)
  121. std::string ToString() const override { return type_name() + "[" + SequeueShape::ToString() + "]"; }
  122. BaseShapePtr Clone() const override { return std::make_shared<ListShape>(SequeueShape::ElementsClone()); }
  123. bool operator==(const BaseShape &other) const override { return SequeueEqual<ListShape>(other); }
  124. };
  125. using ListShapePtr = std::shared_ptr<ListShape>;
  126. } // namespace abstract
  127. } // namespace mindspore
  128. #endif // MINDSPORE_CORE_ABSTRACT_DSHAPE_H_