/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef PIPELINE_STATIC_ANALYSIS_ABSTRACT_VALUE_H_
#define PIPELINE_STATIC_ANALYSIS_ABSTRACT_VALUE_H_

#include <functional>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "utils/log_adapter.h"
#include "utils/hashing.h"
#include "ir/base.h"
#include "ir/dtype.h"
#include "ir/value.h"
#include "ir/meta_tensor.h"
#include "pipeline/static_analysis/dshape.h"

namespace mindspore {
namespace abstract {
class AbstractBase;
using AbstractBasePtrList = std::vector<AbstractBasePtr>;

// The base class for abstract value. The abstract value is used in inferring
// to express the type, shape, and value of the real value.
class AbstractBase : public Base { public: explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType, const BaseShapePtr &shape = kNoShape) : value_(value), type_(type), shape_(shape) {} ~AbstractBase() override = default; MS_DECLARE_PARENT(AbstractBase, Base) std::size_t hash() const override { return tid(); } std::string ToString() const override; virtual bool operator==(const AbstractBase &other) const; void set_value(const ValuePtr &value) { value_ = value; } void set_type(const TypePtr &type) { type_ = type; } void set_shape(const BaseShapePtr &shape) { shape_ = shape; } void set_value_desc(const std::string &desc) { value_desc_ = desc; } const std::string &value_desc() const { return value_desc_; } ValuePtr GetValueTrack() const { return value_; } TypePtr GetTypeTrack() const { return type_; } BaseShapePtr GetShapeTrack() const { return shape_; } // Try build a real value from an abstract value. If the value cannot be built, // a default value (AnyValue) is returned. 
ValuePtr BuildValue() const; virtual TypePtr BuildType() const = 0; virtual BaseShapePtr BuildShape() const { return kNoShape; } virtual AbstractBasePtr Clone() const = 0; virtual AbstractBasePtr Broaden() const; virtual AbstractBasePtr Join(const AbstractBasePtr &) { return shared_from_base(); } friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr &a) { os << a->ToString(); return os; } protected: // default implementation, it can be overwritten by subclass; virtual ValuePtr RealBuildValue() const { return kAnyValue; } private: ValuePtr value_; TypePtr type_; BaseShapePtr shape_; std::string value_desc_; // store initial value description for error report }; class AbstractScalar : public AbstractBase { public: AbstractScalar() : AbstractBase(kAnyValue, kAnyType) {} explicit AbstractScalar(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {} explicit AbstractScalar(const ValuePtr &value) : AbstractBase(value, value->type()) {} explicit AbstractScalar(int value) : AbstractBase(MakeValue(value), kInt32) {} explicit AbstractScalar(float value) : AbstractBase(MakeValue(value), kFloat32) {} explicit AbstractScalar(double value) : AbstractBase(MakeValue(value), kFloat64) {} explicit AbstractScalar(bool value) : AbstractBase(MakeValue(value), kBool) {} explicit AbstractScalar(const std::string &value) : AbstractBase(MakeValue(value), kString) {} explicit AbstractScalar(const TypePtr &type) : AbstractBase(kAnyValue, type) {} ~AbstractScalar() override = default; MS_DECLARE_PARENT(AbstractScalar, AbstractBase) std::size_t hash() const override { return hash_combine({tid(), GetValueTrack()->hash(), GetTypeTrack()->hash()}); } TypePtr BuildType() const override { return GetTypeTrack(); } AbstractBasePtr Clone() const override { return std::make_shared(GetValueTrack(), GetTypeTrack()->Clone()); } AbstractBasePtr Broaden() const override; AbstractBasePtr Join(const AbstractBasePtr &other) override; }; using AbstractScalarPtr = 
std::shared_ptr; class AbstractType : public AbstractBase { public: explicit AbstractType(const TypePtr &type) : AbstractBase(type, kTypeType) { if (type == nullptr) { MS_LOG(EXCEPTION) << "type is nullptr"; } } ~AbstractType() override = default; MS_DECLARE_PARENT(AbstractType, AbstractBase) std::string ToString() const override; bool operator==(const AbstractBase &other) const override; TypePtr BuildType() const override { return std::make_shared(); } AbstractBasePtr Clone() const override; AbstractBasePtr Broaden() const override { return Clone(); } }; using AbstractTypePtr = std::shared_ptr; class AbstractError : public AbstractBase { public: explicit AbstractError(const StringImmPtr &err, const AnfNodePtr &node) : AbstractBase(err), node_(node) { if (err == nullptr || node == nullptr) { MS_LOG(EXCEPTION) << "err or node is nullptr"; } } ~AbstractError() override = default; MS_DECLARE_PARENT(AbstractError, AbstractBase) TypePtr BuildType() const override { return std::make_shared(); } AbstractBasePtr Broaden() const override { return Clone(); } AbstractBasePtr Clone() const override { return std::make_shared(GetValueTrack()->cast(), node_); } std::string ToString() const override; private: // Origin node been specialized to AbstractError, for debug purpose only. const AnfNodePtr node_; }; class Evaluator; using EvaluatorPtr = std::shared_ptr; class AnalysisEngine; using AnalysisEnginePtr = std::shared_ptr; class AbstractFunction; using AbstractFunctionPtr = std::shared_ptr; class AbstractFuncAtom; using AbstractFuncAtomPtr = std::shared_ptr; using AbstractFuncAtomPtrList = std::vector; class AbstractFunction : public AbstractBase { public: AbstractFunction() = default; ~AbstractFunction() override = default; MS_DECLARE_PARENT(AbstractFunction, AbstractBase) // If there is exactly one possible function, return it. Otherwise, raise an Exception. // Caller should ensure the uniqueness. 
virtual AbstractFunctionPtr GetUnique() = 0; TypePtr BuildType() const override { return std::make_shared(); } AbstractBasePtr Clone() const override { return Copy(); } // For Function, no need to broaden. AbstractBasePtr Broaden() const override { return const_cast(this)->shared_from_base(); } virtual AbstractFunctionPtr Copy() const = 0; AbstractBasePtr Join(const AbstractBasePtr &other) final; virtual AbstractFunctionPtr Join(const AbstractFunctionPtr &other) = 0; virtual void Visit(std::function) const = 0; bool operator==(const AbstractBase &other) const final; virtual bool operator==(const AbstractFunction &other) const = 0; static AbstractFunctionPtr MakeAbstractFunction(const AbstractFuncAtomPtrList &func_list); virtual EvaluatorPtr GetEvaluator(AnalysisEnginePtr engine) = 0; virtual AnfNodePtr tracking_id() const { return nullptr; } virtual void set_tracking_id(AnfNodePtr) {} virtual AnalysisContextPtr context() const { return nullptr; } }; using AbstractFunctionPtrList = std::vector; // Represents a key-value pair used in function's parameters. 
class AbstractKeywordArg : public AbstractBase { public: AbstractKeywordArg(const std::string &key, const AbstractBasePtr &argument) : arg_name_(key), arg_value_(argument) {} ~AbstractKeywordArg() override = default; MS_DECLARE_PARENT(AbstractKeywordArg, AbstractBase) TypePtr BuildType() const override; AbstractBasePtr Clone() const override; AbstractBasePtr Broaden() const override; std::size_t hash() const override; bool operator==(const AbstractKeywordArg &other) const; bool operator==(const AbstractBase &other) const override; std::string get_key() const { return arg_name_; } AbstractBasePtr get_arg() const { return arg_value_; } std::string ToString() const override; protected: ValuePtr RealBuildValue() const override; private: std::string arg_name_; AbstractBasePtr arg_value_; }; using AbstractKeywordArgPtr = std::shared_ptr; class AbstractTensor : public AbstractBase { public: // only element_ and value, shape track are valid member, type track are unknown. explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) : AbstractBase(kAnyValue), element_(element) { if (element == nullptr) { MS_LOG(EXCEPTION) << "element is nullptr"; } if (element->isa()) { MS_LOG(EXCEPTION) << "element type error"; } set_shape(shape); } AbstractTensor(const TypePtr &element_type, const std::vector &shape) : AbstractBase(kAnyValue), element_(std::make_shared(kAnyValue, element_type)) { if (element_type == nullptr) { MS_LOG(EXCEPTION) << "element_type is nullptr"; } set_shape(std::make_shared(shape)); } explicit AbstractTensor(const tensor::TensorPtr &tensor) : AbstractBase(tensor), element_(std::make_shared(kAnyValue, tensor->Dtype())) { if (tensor == nullptr) { MS_LOG(EXCEPTION) << "tensor is nullptr"; } set_shape(std::make_shared(tensor->shape())); } ~AbstractTensor() override = default; MS_DECLARE_PARENT(AbstractTensor, AbstractBase) TypePtr BuildType() const override; BaseShapePtr BuildShape() const override; AbstractBasePtr 
Clone() const override; AbstractBasePtr Broaden() const override; AbstractBasePtr BroadenWithShape() const; AbstractBasePtr Join(const AbstractBasePtr &other) final; bool operator==(const AbstractTensor &other) const; bool operator==(const AbstractBase &other) const override; ShapePtr shape() const; std::string ToString() const override; const AbstractBasePtr element() const { return element_; } std::size_t hash() const override { auto value = GetValueTrack(); auto hash_sum = hash_combine(tid(), element_->hash()); if (value != nullptr) { auto tensor = value->cast(); if (tensor != nullptr) { hash_sum = hash_combine(hash_sum, IntToSize(tensor->DataSize())); } } return hash_sum; } private: AbstractBasePtr element_; }; using AbstractTensorPtr = std::shared_ptr; using AbstractTensorPtrList = std::vector; class AbstractSequeue : public AbstractBase { public: explicit AbstractSequeue(const AbstractBasePtrList &elements) : elements_(elements) {} ~AbstractSequeue() override = default; MS_DECLARE_PARENT(AbstractSequeue, AbstractBase) TypePtrList ElementsType() const; BaseShapePtrList ElementsShape() const; AbstractBasePtrList ElementsClone() const; AbstractBasePtrList ElementsBroaden() const; template ValuePtr ElementsBuildValue() const; template AbstractBasePtr ElementsJoin(const AbstractBasePtr &other); std::size_t size() const { return elements_.size(); } const AbstractBasePtrList &elements() const { return elements_; } std::size_t hash() const override; std::string ToString() const override; const AbstractBasePtr operator[](const std::size_t &dim) const; protected: AbstractBasePtrList elements_; }; using AbstractSequeuePtr = std::shared_ptr; class AbstractTuple : public AbstractSequeue { public: explicit AbstractTuple(const AbstractBasePtrList &elements) : AbstractSequeue(elements) {} ~AbstractTuple() override = default; MS_DECLARE_PARENT(AbstractTuple, AbstractSequeue) TypePtr BuildType() const override { return std::make_shared(ElementsType()); } BaseShapePtr 
BuildShape() const override { return std::make_shared(ElementsShape()); } AbstractBasePtr Clone() const override { return std::make_shared(ElementsClone()); } AbstractBasePtr Broaden() const override { return std::make_shared(ElementsBroaden()); } AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin(other); } std::string ToString() const override { return type_name() + "(" + AbstractSequeue::ToString() + ")"; } bool operator==(const AbstractTuple &other) const; bool operator==(const AbstractBase &other) const override; protected: ValuePtr RealBuildValue() const override { return ElementsBuildValue(); } }; using AbstractTuplePtr = std::shared_ptr; class AbstractList : public AbstractSequeue { public: explicit AbstractList(const AbstractBasePtrList &elements) : AbstractSequeue(elements) {} ~AbstractList() override = default; MS_DECLARE_PARENT(AbstractList, AbstractSequeue) TypePtr BuildType() const override { return std::make_shared(ElementsType()); } BaseShapePtr BuildShape() const override { return std::make_shared(ElementsShape()); } AbstractBasePtr Clone() const override { return std::make_shared(ElementsClone()); } AbstractBasePtr Broaden() const override { return std::make_shared(ElementsBroaden()); } AbstractBasePtr Join(const AbstractBasePtr &other) override { return ElementsJoin(other); } std::string ToString() const override { return type_name() + "[" + AbstractSequeue::ToString() + "]"; } bool operator==(const AbstractList &other) const; bool operator==(const AbstractBase &other) const override; protected: ValuePtr RealBuildValue() const override { return ElementsBuildValue(); } }; using AbstractListPtr = std::shared_ptr; class AbstractClass : public AbstractBase { public: AbstractClass(const Named &tag, const std::vector &attributes, const std::unordered_map &methods) : attributes_(attributes), tag_(tag), methods_(methods) {} ~AbstractClass() override = default; MS_DECLARE_PARENT(AbstractClass, AbstractBase) TypePtr 
BuildType() const override; bool operator==(const AbstractClass &other) const; bool operator==(const AbstractBase &other) const override; const std::vector &attributes() const { return attributes_; } std::unordered_map methods() { return methods_; } AbstractBasePtr GetAttribute(const std::string &name); ValuePtr GetMethod(const std::string &name); AbstractBasePtr Clone() const override; AbstractBasePtr Broaden() const override; std::string ToString() const override; Named tag() const { return tag_; } std::size_t hash() const override; protected: ValuePtr RealBuildValue() const override; private: std::vector attributes_; Named tag_; std::unordered_map methods_; }; using AbstractClassPtr = std::shared_ptr; class AbstractDictionary : public AbstractBase { public: explicit AbstractDictionary(const std::vector &key_values) : key_values_(key_values) {} ~AbstractDictionary() override = default; MS_DECLARE_PARENT(AbstractDictionary, AbstractBase) TypePtr BuildType() const override; bool operator==(const AbstractDictionary &other) const; bool operator==(const AbstractBase &other) const override; AbstractBasePtr Clone() const override; AbstractBasePtr Broaden() const override; std::string ToString() const override; std::size_t hash() const override; std::size_t size() const { return key_values_.size(); } const std::vector &elements() const { return key_values_; } std::vector key_values_; protected: ValuePtr RealBuildValue() const override; }; using AbstractDictionaryPtr = std::shared_ptr; class AbstractSlice : public AbstractBase { public: AbstractSlice(const AbstractBasePtr &start, const AbstractBasePtr &stop, const AbstractBasePtr &step) : start_(start), stop_(stop), step_(step) {} ~AbstractSlice() override = default; MS_DECLARE_PARENT(AbstractSlice, AbstractBase) TypePtr BuildType() const override; bool operator==(const AbstractSlice &other) const; bool operator==(const AbstractBase &other) const override; AbstractBasePtr Clone() const override; AbstractBasePtr Broaden() 
const override; std::string ToString() const override; std::size_t hash() const override; AbstractBasePtr start() const { return start_; } AbstractBasePtr stop() const { return stop_; } AbstractBasePtr step() const { return step_; } protected: ValuePtr RealBuildValue() const override; private: AbstractBasePtr start_; AbstractBasePtr stop_; AbstractBasePtr step_; }; using AbstractSlicePtr = std::shared_ptr; class AbstractJTagged : public AbstractBase { public: explicit AbstractJTagged(const AbstractBasePtr &element) : element_(element) {} ~AbstractJTagged() override = default; MS_DECLARE_PARENT(AbstractJTagged, AbstractBase) TypePtr BuildType() const override; AbstractBasePtr Clone() const override { return std::make_shared(element_->Clone()); } AbstractBasePtr Broaden() const override { return std::make_shared(element_->Broaden()); } AbstractBasePtr Join(const AbstractBasePtr &other) override; bool operator==(const AbstractJTagged &other) const; bool operator==(const AbstractBase &other) const override; std::string ToString() const override; AbstractBasePtr element() { return element_; } std::size_t hash() const override { return hash_combine(tid(), element_->hash()); } private: AbstractBasePtr element_; }; using AbstractJTaggedPtr = std::shared_ptr; class AbstractNone : public AbstractBase { public: AbstractNone() : AbstractBase() { set_type(std::make_shared()); } ~AbstractNone() override = default; MS_DECLARE_PARENT(AbstractNone, AbstractBase) TypePtr BuildType() const override { return std::make_shared(); } bool operator==(const AbstractNone &other) const; bool operator==(const AbstractBase &other) const override; AbstractBasePtr Clone() const override { return std::make_shared(); } std::string ToString() const override; protected: ValuePtr RealBuildValue() const override; }; using AbstractNonePtr = std::shared_ptr; // the un assigned state value for variable, which means the variable is not assigned class AbstractNull : public AbstractBase { public: 
AbstractNull() : AbstractBase(kNull) { set_type(std::make_shared()); } ~AbstractNull() override = default; MS_DECLARE_PARENT(AbstractNull, AbstractBase) TypePtr BuildType() const override { return std::make_shared(); } bool operator==(const AbstractNull &other) const; bool operator==(const AbstractBase &other) const override; AbstractBasePtr Clone() const override { return std::make_shared(); } std::string ToString() const override; }; using AbstractNullPtr = std::shared_ptr; class AbstractEllipsis : public AbstractBase { public: AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared()); } ~AbstractEllipsis() override = default; MS_DECLARE_PARENT(AbstractEllipsis, AbstractBase) TypePtr BuildType() const override { return std::make_shared(); } bool operator==(const AbstractEllipsis &other) const; bool operator==(const AbstractBase &other) const override; AbstractBasePtr Clone() const override { return std::make_shared(); } std::string ToString() const override; }; using AbstractEllipsisPtr = std::shared_ptr; class AbstractRefKey : public AbstractBase { public: AbstractRefKey() : AbstractBase() { set_type(std::make_shared()); } ~AbstractRefKey() override = default; MS_DECLARE_PARENT(AbstractRefKey, AbstractBase) TypePtr BuildType() const override { return std::make_shared(); } bool operator==(const AbstractRefKey &other) const; bool operator==(const AbstractBase &other) const override; AbstractBasePtr Clone() const override { return std::make_shared(); } std::string ToString() const override; }; using AbstractRefKeyPtr = std::shared_ptr; class AbstractRef : public AbstractBase { public: AbstractRef(const AbstractBasePtr &ref_key, const AbstractBasePtr &ref_value, const AbstractBasePtr &ref_origin) : ref_key_(ref_key), ref_(ref_value), ref_origin_(ref_origin) { set_type(std::make_shared()); } ~AbstractRef() override = default; MS_DECLARE_PARENT(AbstractRef, AbstractBase) TypePtr BuildType() const override; bool operator==(const AbstractRef &other) 
const; bool operator==(const AbstractBase &other) const override; AbstractBasePtr Clone() const override { return std::make_shared(ref_key_->Clone(), ref_->Clone(), ref_origin_->Clone()); } std::string ToString() const override; AbstractBasePtr ref() { return ref_; } AbstractBasePtr ref_origin() { return ref_origin_; } AbstractBasePtr ref_key() { return ref_key_; } AbstractBasePtr Broaden() const override { return std::make_shared(ref_key_->Broaden(), ref_->Broaden(), ref_origin_->Broaden()); } std::size_t hash() const override { return ref_key_->hash() ^ ref_->hash() ^ ref_origin_->hash() ^ (std::hash{}(this->tid()) << 1); } private: AbstractBasePtr ref_key_; AbstractBasePtr ref_; AbstractBasePtr ref_origin_; }; using AbstractRefPtr = std::shared_ptr; struct AbstractBasePtrListHasher { std::size_t operator()(const AbstractBasePtrList &args_spec_list) const; }; struct AbstractBasePtrListEqual { bool operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const; }; std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list); bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs); } // namespace abstract } // namespace mindspore #endif // PIPELINE_STATIC_ANALYSIS_ABSTRACT_VALUE_H_