|
|
@@ -69,7 +69,14 @@ class AbstractBase : public Base { |
|
|
virtual TypePtr BuildType() const = 0; |
|
|
virtual TypePtr BuildType() const = 0; |
|
|
virtual BaseShapePtr BuildShape() const { return kNoShape; } |
|
|
virtual BaseShapePtr BuildShape() const { return kNoShape; } |
|
|
virtual AbstractBasePtr Clone() const = 0; |
|
|
virtual AbstractBasePtr Clone() const = 0; |
|
|
virtual AbstractBasePtr Broaden() const; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// mask for Broaden config |
|
|
|
|
|
inline static const uint8_t kBroadenTensorOnly = 1; |
|
|
|
|
|
inline static const uint8_t kBroadenParameterOnly = 2; |
|
|
|
|
|
// Each bit for on config. |
|
|
|
|
|
// 00000001 -> 1: only boarden tensor |
|
|
|
|
|
// 00000010 -> 2: only boarden parameter |
|
|
|
|
|
virtual AbstractBasePtr Broaden(uint8_t config = 0) const; |
|
|
virtual AbstractBasePtr Join(const AbstractBasePtr &) { return shared_from_base<AbstractBase>(); } |
|
|
virtual AbstractBasePtr Join(const AbstractBasePtr &) { return shared_from_base<AbstractBase>(); } |
|
|
|
|
|
|
|
|
friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<AbstractBase> &a) { |
|
|
friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr<AbstractBase> &a) { |
|
|
@@ -108,7 +115,7 @@ class AbstractScalar : public AbstractBase { |
|
|
AbstractBasePtr Clone() const override { |
|
|
AbstractBasePtr Clone() const override { |
|
|
return std::make_shared<AbstractScalar>(GetValueTrack(), GetTypeTrack()->Clone()); |
|
|
return std::make_shared<AbstractScalar>(GetValueTrack(), GetTypeTrack()->Clone()); |
|
|
} |
|
|
} |
|
|
AbstractBasePtr Broaden() const override; |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden(uint8_t config = 0) const override; |
|
|
AbstractBasePtr Join(const AbstractBasePtr &other) override; |
|
|
AbstractBasePtr Join(const AbstractBasePtr &other) override; |
|
|
}; |
|
|
}; |
|
|
using AbstractScalarPtr = std::shared_ptr<AbstractScalar>; |
|
|
using AbstractScalarPtr = std::shared_ptr<AbstractScalar>; |
|
|
@@ -128,7 +135,7 @@ class AbstractType : public AbstractBase { |
|
|
|
|
|
|
|
|
TypePtr BuildType() const override { return std::make_shared<TypeType>(); } |
|
|
TypePtr BuildType() const override { return std::make_shared<TypeType>(); } |
|
|
AbstractBasePtr Clone() const override; |
|
|
AbstractBasePtr Clone() const override; |
|
|
AbstractBasePtr Broaden() const override { return Clone(); } |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden(uint8_t config = 0) const override { return Clone(); } |
|
|
}; |
|
|
}; |
|
|
using AbstractTypePtr = std::shared_ptr<AbstractType>; |
|
|
using AbstractTypePtr = std::shared_ptr<AbstractType>; |
|
|
|
|
|
|
|
|
@@ -143,7 +150,7 @@ class AbstractError : public AbstractBase { |
|
|
MS_DECLARE_PARENT(AbstractError, AbstractBase) |
|
|
MS_DECLARE_PARENT(AbstractError, AbstractBase) |
|
|
|
|
|
|
|
|
TypePtr BuildType() const override { return std::make_shared<Problem>(); } |
|
|
TypePtr BuildType() const override { return std::make_shared<Problem>(); } |
|
|
AbstractBasePtr Broaden() const override { return Clone(); } |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden(uint8_t config = 0) const override { return Clone(); } |
|
|
|
|
|
|
|
|
AbstractBasePtr Clone() const override { |
|
|
AbstractBasePtr Clone() const override { |
|
|
return std::make_shared<AbstractError>(GetValueTrack()->cast<StringImmPtr>(), node_); |
|
|
return std::make_shared<AbstractError>(GetValueTrack()->cast<StringImmPtr>(), node_); |
|
|
@@ -180,7 +187,7 @@ class AbstractFunction : public AbstractBase { |
|
|
TypePtr BuildType() const override { return std::make_shared<Function>(); } |
|
|
TypePtr BuildType() const override { return std::make_shared<Function>(); } |
|
|
AbstractBasePtr Clone() const override { return Copy(); } |
|
|
AbstractBasePtr Clone() const override { return Copy(); } |
|
|
// For Function, no need to broaden. |
|
|
// For Function, no need to broaden. |
|
|
AbstractBasePtr Broaden() const override { |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden(uint8_t config = 0) const override { |
|
|
return const_cast<AbstractFunction *>(this)->shared_from_base<AbstractFunction>(); |
|
|
return const_cast<AbstractFunction *>(this)->shared_from_base<AbstractFunction>(); |
|
|
} |
|
|
} |
|
|
virtual AbstractFunctionPtr Copy() const = 0; |
|
|
virtual AbstractFunctionPtr Copy() const = 0; |
|
|
@@ -209,7 +216,7 @@ class AbstractKeywordArg : public AbstractBase { |
|
|
|
|
|
|
|
|
TypePtr BuildType() const override; |
|
|
TypePtr BuildType() const override; |
|
|
AbstractBasePtr Clone() const override; |
|
|
AbstractBasePtr Clone() const override; |
|
|
AbstractBasePtr Broaden() const override; |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden(uint8_t config = 0) const override; |
|
|
std::size_t hash() const override; |
|
|
std::size_t hash() const override; |
|
|
|
|
|
|
|
|
bool operator==(const AbstractKeywordArg &other) const; |
|
|
bool operator==(const AbstractKeywordArg &other) const; |
|
|
@@ -275,7 +282,7 @@ class AbstractTensor : public AbstractUndetermined { |
|
|
TypePtr BuildType() const override; |
|
|
TypePtr BuildType() const override; |
|
|
BaseShapePtr BuildShape() const override; |
|
|
BaseShapePtr BuildShape() const override; |
|
|
AbstractBasePtr Clone() const override; |
|
|
AbstractBasePtr Clone() const override; |
|
|
AbstractBasePtr Broaden() const override; |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden(uint8_t config = 0) const override; |
|
|
AbstractBasePtr BroadenWithShape() const; |
|
|
AbstractBasePtr BroadenWithShape() const; |
|
|
AbstractBasePtr Join(const AbstractBasePtr &other) final; |
|
|
AbstractBasePtr Join(const AbstractBasePtr &other) final; |
|
|
int format() const { return this->format_; } |
|
|
int format() const { return this->format_; } |
|
|
@@ -312,7 +319,7 @@ class AbstractSequeue : public AbstractBase { |
|
|
TypePtrList ElementsType() const; |
|
|
TypePtrList ElementsType() const; |
|
|
BaseShapePtrList ElementsShape() const; |
|
|
BaseShapePtrList ElementsShape() const; |
|
|
AbstractBasePtrList ElementsClone() const; |
|
|
AbstractBasePtrList ElementsClone() const; |
|
|
AbstractBasePtrList ElementsBroaden() const; |
|
|
|
|
|
|
|
|
AbstractBasePtrList ElementsBroaden(uint8_t config = 0) const; |
|
|
|
|
|
|
|
|
template <typename T> |
|
|
template <typename T> |
|
|
ValuePtr ElementsBuildValue() const; |
|
|
ValuePtr ElementsBuildValue() const; |
|
|
@@ -345,7 +352,9 @@ class AbstractTuple : public AbstractSequeue { |
|
|
|
|
|
|
|
|
AbstractBasePtr Clone() const override { return std::make_shared<AbstractTuple>(ElementsClone()); } |
|
|
AbstractBasePtr Clone() const override { return std::make_shared<AbstractTuple>(ElementsClone()); } |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden() const override { return std::make_shared<AbstractTuple>(ElementsBroaden()); } |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden(uint8_t config = 0) const override { |
|
|
|
|
|
return std::make_shared<AbstractTuple>(ElementsBroaden(config)); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractTuple>(other); } |
|
|
AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractTuple>(other); } |
|
|
|
|
|
|
|
|
@@ -372,7 +381,9 @@ class AbstractList : public AbstractSequeue { |
|
|
|
|
|
|
|
|
AbstractBasePtr Clone() const override { return std::make_shared<AbstractList>(ElementsClone()); } |
|
|
AbstractBasePtr Clone() const override { return std::make_shared<AbstractList>(ElementsClone()); } |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden() const override { return std::make_shared<AbstractList>(ElementsBroaden()); } |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden(uint8_t config = 0) const override { |
|
|
|
|
|
return std::make_shared<AbstractList>(ElementsBroaden(config)); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractList>(other); } |
|
|
AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin<AbstractList>(other); } |
|
|
|
|
|
|
|
|
@@ -403,7 +414,7 @@ class AbstractClass : public AbstractBase { |
|
|
AbstractBasePtr GetAttribute(const std::string &name); |
|
|
AbstractBasePtr GetAttribute(const std::string &name); |
|
|
ValuePtr GetMethod(const std::string &name); |
|
|
ValuePtr GetMethod(const std::string &name); |
|
|
AbstractBasePtr Clone() const override; |
|
|
AbstractBasePtr Clone() const override; |
|
|
AbstractBasePtr Broaden() const override; |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden(uint8_t config = 0) const override; |
|
|
std::string ToString() const override; |
|
|
std::string ToString() const override; |
|
|
Named tag() const { return tag_; } |
|
|
Named tag() const { return tag_; } |
|
|
std::size_t hash() const override; |
|
|
std::size_t hash() const override; |
|
|
@@ -428,7 +439,7 @@ class AbstractDictionary : public AbstractBase { |
|
|
bool operator==(const AbstractDictionary &other) const; |
|
|
bool operator==(const AbstractDictionary &other) const; |
|
|
bool operator==(const AbstractBase &other) const override; |
|
|
bool operator==(const AbstractBase &other) const override; |
|
|
AbstractBasePtr Clone() const override; |
|
|
AbstractBasePtr Clone() const override; |
|
|
AbstractBasePtr Broaden() const override; |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden(uint8_t config = 0) const override; |
|
|
std::string ToString() const override; |
|
|
std::string ToString() const override; |
|
|
std::size_t hash() const override; |
|
|
std::size_t hash() const override; |
|
|
std::size_t size() const { return key_values_.size(); } |
|
|
std::size_t size() const { return key_values_.size(); } |
|
|
@@ -452,7 +463,7 @@ class AbstractSlice : public AbstractBase { |
|
|
bool operator==(const AbstractSlice &other) const; |
|
|
bool operator==(const AbstractSlice &other) const; |
|
|
bool operator==(const AbstractBase &other) const override; |
|
|
bool operator==(const AbstractBase &other) const override; |
|
|
AbstractBasePtr Clone() const override; |
|
|
AbstractBasePtr Clone() const override; |
|
|
AbstractBasePtr Broaden() const override; |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden(uint8_t config = 0) const override; |
|
|
std::string ToString() const override; |
|
|
std::string ToString() const override; |
|
|
std::size_t hash() const override; |
|
|
std::size_t hash() const override; |
|
|
AbstractBasePtr start() const { return start_; } |
|
|
AbstractBasePtr start() const { return start_; } |
|
|
@@ -478,7 +489,9 @@ class AbstractJTagged : public AbstractBase { |
|
|
|
|
|
|
|
|
TypePtr BuildType() const override; |
|
|
TypePtr BuildType() const override; |
|
|
AbstractBasePtr Clone() const override { return std::make_shared<AbstractJTagged>(element_->Clone()); } |
|
|
AbstractBasePtr Clone() const override { return std::make_shared<AbstractJTagged>(element_->Clone()); } |
|
|
AbstractBasePtr Broaden() const override { return std::make_shared<AbstractJTagged>(element_->Broaden()); } |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden(uint8_t config = 0) const override { |
|
|
|
|
|
return std::make_shared<AbstractJTagged>(element_->Broaden(config)); |
|
|
|
|
|
} |
|
|
AbstractBasePtr Join(const AbstractBasePtr &other) override; |
|
|
AbstractBasePtr Join(const AbstractBasePtr &other) override; |
|
|
|
|
|
|
|
|
bool operator==(const AbstractJTagged &other) const; |
|
|
bool operator==(const AbstractJTagged &other) const; |
|
|
@@ -558,7 +571,7 @@ class AbstractRefKey : public AbstractBase { |
|
|
} |
|
|
} |
|
|
RefKeyPtr ref_key_value() const { return ref_key_value_; } |
|
|
RefKeyPtr ref_key_value() const { return ref_key_value_; } |
|
|
AbstractBasePtr Join(const AbstractBasePtr &other) override; |
|
|
AbstractBasePtr Join(const AbstractBasePtr &other) override; |
|
|
AbstractBasePtr Broaden() const override; |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden(uint8_t config = 0) const override; |
|
|
std::string ToString() const override; |
|
|
std::string ToString() const override; |
|
|
|
|
|
|
|
|
private: |
|
|
private: |
|
|
@@ -588,8 +601,9 @@ class AbstractRef : public AbstractBase { |
|
|
inline RefKeyPtr ref_key_value() const { return ref_key_value_; } |
|
|
inline RefKeyPtr ref_key_value() const { return ref_key_value_; } |
|
|
inline TypePtr target_type() const { return target_type_; } |
|
|
inline TypePtr target_type() const { return target_type_; } |
|
|
inline bool need_cast() const { return need_cast_; } |
|
|
inline bool need_cast() const { return need_cast_; } |
|
|
AbstractBasePtr Broaden() const override { |
|
|
|
|
|
return std::make_shared<AbstractRef>(ref_key_->Broaden(), ref_->Broaden(), need_cast_, target_type_); |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden(uint8_t config = 0) const override { |
|
|
|
|
|
// always broaden for ref |
|
|
|
|
|
return std::make_shared<AbstractRef>(ref_key_->Broaden(config), ref_->Broaden(), need_cast_, target_type_); |
|
|
} |
|
|
} |
|
|
AbstractBasePtr Join(const AbstractBasePtr &other) override; |
|
|
AbstractBasePtr Join(const AbstractBasePtr &other) override; |
|
|
std::size_t hash() const override { |
|
|
std::size_t hash() const override { |
|
|
@@ -636,7 +650,7 @@ class AbstractRowTensor : public AbstractUndetermined { |
|
|
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; } |
|
|
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; } |
|
|
TypePtr BuildType() const override; |
|
|
TypePtr BuildType() const override; |
|
|
AbstractBasePtr Clone() const override; |
|
|
AbstractBasePtr Clone() const override; |
|
|
AbstractBasePtr Broaden() const override; |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden(uint8_t config = 0) const override; |
|
|
AbstractBasePtr BroadenWithShape() const; |
|
|
AbstractBasePtr BroadenWithShape() const; |
|
|
|
|
|
|
|
|
std::string ToString() const override; |
|
|
std::string ToString() const override; |
|
|
@@ -665,7 +679,7 @@ class AbstractSparseTensor : public AbstractUndetermined { |
|
|
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; } |
|
|
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; } |
|
|
TypePtr BuildType() const override; |
|
|
TypePtr BuildType() const override; |
|
|
AbstractBasePtr Clone() const override; |
|
|
AbstractBasePtr Clone() const override; |
|
|
AbstractBasePtr Broaden() const override; |
|
|
|
|
|
|
|
|
AbstractBasePtr Broaden(uint8_t config = 0) const override; |
|
|
AbstractBasePtr BroadenWithShape() const; |
|
|
AbstractBasePtr BroadenWithShape() const; |
|
|
|
|
|
|
|
|
std::string ToString() const override; |
|
|
std::string ToString() const override; |
|
|
|