You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dshape.h 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #ifndef PIPELINE_STATIC_ANALYSIS_DSHAPE_H_
  19. #define PIPELINE_STATIC_ANALYSIS_DSHAPE_H_
  20. #include <vector>
  21. #include <string>
  22. #include <sstream>
  23. #include <unordered_map>
  24. #include <typeindex>
  25. #include <memory>
  26. #include "utils/log_adapter.h"
  27. #include "ir/base.h"
  28. namespace mindspore {
  29. namespace abstract {
  30. class BaseShape;
  31. using BaseShapePtr = std::shared_ptr<BaseShape>;
  32. using BaseShapePtrList = std::vector<BaseShapePtr>;
  33. class BaseShape : public Base {
  34. public:
  35. BaseShape() = default;
  36. ~BaseShape() override = default;
  37. MS_DECLARE_PARENT(BaseShape, Base)
  38. virtual bool operator==(const BaseShape &other) const;
  39. bool operator!=(const BaseShape &other) const;
  40. std::size_t hash() const override { return tid(); }
  41. // return a deep copy
  42. virtual BaseShapePtr Clone() const = 0;
  43. virtual void Broaden() {}
  44. };
  45. class NoShape : public BaseShape {
  46. public:
  47. MS_DECLARE_PARENT(NoShape, BaseShape)
  48. BaseShapePtr Clone() const override { return std::make_shared<NoShape>(); }
  49. std::string ToString() const override { return type_name(); }
  50. };
  51. extern const std::shared_ptr<NoShape> kNoShape;
  52. class Shape : public BaseShape {
  53. public:
  54. static const int SHP_ANY = -1;
  55. Shape() : shape_() {}
  56. Shape(const std::initializer_list<int> &list) : shape_(list) {}
  57. explicit Shape(const std::vector<int> &list) : shape_(list) {}
  58. ~Shape() override = default;
  59. MS_DECLARE_PARENT(Shape, BaseShape)
  60. std::string ToString() const override;
  61. std::string DumpText() const override;
  62. bool operator==(const BaseShape &other) const override;
  63. BaseShapePtr Clone() const override { return std::make_shared<Shape>(shape_); }
  64. void Broaden() override;
  65. std::vector<int> &shape() { return shape_; }
  66. std::vector<int> shape_; // use SHP_ANY to implement the any shape in python
  67. };
  68. using ShapePtr = std::shared_ptr<Shape>;
  69. using ShapePtrList = std::vector<ShapePtr>;
  70. class SequeueShape : public BaseShape {
  71. public:
  72. SequeueShape() : p_shapes_() {}
  73. explicit SequeueShape(const BaseShapePtrList &shapes) : p_shapes_(shapes) {}
  74. ~SequeueShape() override = default;
  75. MS_DECLARE_PARENT(SequeueShape, BaseShape)
  76. std::string ToString() const override;
  77. BaseShapePtrList ElementsClone() const;
  78. template <typename T>
  79. bool SequeueEqual(const BaseShape &other) const;
  80. const BaseShapePtrList &shape() const { return p_shapes_; }
  81. size_t size() const { return p_shapes_.size(); }
  82. const BaseShapePtr operator[](std::size_t dim) const { return p_shapes_[dim]; }
  83. protected:
  84. BaseShapePtrList p_shapes_; // shape list of each elements
  85. };
  86. using SequeueShapePtr = std::shared_ptr<SequeueShape>;
  87. class TupleShape : public SequeueShape {
  88. public:
  89. TupleShape() : SequeueShape() {}
  90. explicit TupleShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {}
  91. ~TupleShape() override = default;
  92. MS_DECLARE_PARENT(TupleShape, SequeueShape)
  93. std::string ToString() const override { return type_name() + "(" + SequeueShape::ToString() + ")"; }
  94. BaseShapePtr Clone() const override { return std::make_shared<TupleShape>(ElementsClone()); }
  95. bool operator==(const BaseShape &other) const override { return SequeueEqual<TupleShape>(other); }
  96. };
  97. using TupleShapePtr = std::shared_ptr<TupleShape>;
  98. class ListShape : public SequeueShape {
  99. public:
  100. ListShape() : SequeueShape() {}
  101. explicit ListShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {}
  102. ~ListShape() override = default;
  103. MS_DECLARE_PARENT(ListShape, SequeueShape)
  104. std::string ToString() const override { return type_name() + "[" + SequeueShape::ToString() + "]"; }
  105. BaseShapePtr Clone() const override { return std::make_shared<ListShape>(SequeueShape::ElementsClone()); }
  106. bool operator==(const BaseShape &other) const override { return SequeueEqual<ListShape>(other); }
  107. };
  108. using ListShapePtr = std::shared_ptr<ListShape>;
  109. } // namespace abstract
  110. } // namespace mindspore
  111. #endif // PIPELINE_STATIC_ANALYSIS_DSHAPE_H_