|
|
|
@@ -46,6 +46,7 @@ class BaseShape : public Base { |
|
|
|
virtual bool operator==(const BaseShape &other) const; |
|
|
|
bool operator!=(const BaseShape &other) const; |
|
|
|
std::size_t hash() const override { return tid(); } |
|
|
|
virtual bool IsDynamic() const = 0; |
|
|
|
|
|
|
|
// return a deep copy |
|
|
|
virtual BaseShapePtr Clone() const = 0; |
|
|
|
@@ -57,6 +58,7 @@ class NoShape : public BaseShape { |
|
|
|
MS_DECLARE_PARENT(NoShape, BaseShape) |
|
|
|
BaseShapePtr Clone() const override { return std::make_shared<NoShape>(); } |
|
|
|
std::string ToString() const override { return type_name(); } |
|
|
|
bool IsDynamic() const override { return false; } |
|
|
|
}; |
|
|
|
extern const std::shared_ptr<NoShape> kNoShape; |
|
|
|
|
|
|
|
@@ -78,10 +80,13 @@ class Shape : public BaseShape { |
|
|
|
ShapeVector &shape() { return shape_; } |
|
|
|
ShapeVector &min_shape() { return min_shape_; } |
|
|
|
ShapeVector &max_shape() { return max_shape_; } |
|
|
|
bool IsDynamic() const override { |
|
|
|
return std::any_of(shape_.begin(), shape_.end(), [](int64_t s) { return s < 0; }); |
|
|
|
} |
|
|
|
|
|
|
|
ShapeVector shape_; // use SHP_ANY to implement the any shape in python |
|
|
|
ShapeVector min_shape_; // record mininum length for each dynamic dimention |
|
|
|
ShapeVector max_shape_; // record maximum length for each dynamic dimention |
|
|
|
ShapeVector min_shape_; // record minimum length for each dynamic dimension |
|
|
|
ShapeVector max_shape_; // record maximum length for each dynamic dimension |
|
|
|
}; |
|
|
|
using ShapePtr = std::shared_ptr<Shape>; |
|
|
|
using ShapePtrList = std::vector<ShapePtr>; |
|
|
|
@@ -102,6 +107,9 @@ class SequeueShape : public BaseShape { |
|
|
|
const BaseShapePtrList &shape() const { return p_shapes_; } |
|
|
|
size_t size() const { return p_shapes_.size(); } |
|
|
|
const BaseShapePtr operator[](std::size_t dim) const { return p_shapes_[dim]; } |
|
|
|
bool IsDynamic() const override { |
|
|
|
return std::any_of(p_shapes_.begin(), p_shapes_.end(), [](const BaseShapePtr &bs) { return bs->IsDynamic(); }); |
|
|
|
} |
|
|
|
|
|
|
|
protected: |
|
|
|
BaseShapePtrList p_shapes_; // shape list of each elements |
|
|
|
|