/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/static_analysis/abstract_value.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "utils/symbolic.h"
#include "pipeline/static_analysis/static_analysis.h"
#include "pipeline/static_analysis/utils.h"

namespace mindspore {
namespace abstract {
// Structural equality: same runtime type id and equal value, type and shape tracks.
// Throws when either side has no value track, because such abstracts must not be compared here.
bool AbstractBase::operator==(const AbstractBase &other) const {
  if (tid() != other.tid()) {
    return false;
  }
  if (value_ == nullptr || other.value_ == nullptr) {
    MS_LOG(EXCEPTION) << "If value_ is nullptr, AbstractBase::operator== should not be called. this: "
                      << this->ToString() << ", other: " << other.ToString();
  }
  MS_EXCEPTION_IF_NULL(type_);
  MS_EXCEPTION_IF_NULL(other.type_);
  MS_EXCEPTION_IF_NULL(shape_);
  MS_EXCEPTION_IF_NULL(other.shape_);

  bool value_equal = *value_ == *other.value_;
  bool type_equal = *type_ == *other.type_;
  bool shape_equal = *shape_ == *other.shape_;
  return value_equal && type_equal && shape_equal;
}

// Returns the stored value track, or lets the subclass build one when none is stored.
ValuePtr AbstractBase::BuildValue() const {
  if (value_ == nullptr) {
    return RealBuildValue();
  }
  return value_;
}

// Returns a clone whose value track is replaced by kAnyValue.
AbstractBasePtr AbstractBase::Broaden() const {
  AbstractBasePtr clone = Clone();
  // Some Clone() overrides (e.g. AbstractType::Clone) may return nullptr.
  MS_EXCEPTION_IF_NULL(clone);
  clone->set_value(kAnyValue);
  return clone;
}

// Human-readable dump of the type, value and shape tracks.
std::string AbstractBase::ToString() const {
  std::ostringstream buffer;
  std::string value = std::string("value is null");
  if (value_ != nullptr) {
    value = value_->ToString();
  }
  MS_EXCEPTION_IF_NULL(type_);
  MS_EXCEPTION_IF_NULL(shape_);
  buffer << type_name() << "("
         << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() << ")";
  return buffer.str();
}

// Broadens a scalar to kAnyValue, except that a RefKey-valued scalar is returned as a plain clone.
AbstractBasePtr AbstractScalar::Broaden() const {
  AbstractBasePtr clone = Clone();
  MS_EXCEPTION_IF_NULL(clone);
  auto value_track = clone->GetValueTrack();
  MS_EXCEPTION_IF_NULL(value_track);
  if (value_track->isa<RefKey>()) {
    return clone;
  }
  // Equivalent to AbstractBase::Broaden(), but reuses the clone already made instead of cloning again.
  clone->set_value(kAnyValue);
  return clone;
}

// Joins two scalars; returns this object itself when the join does not change the value track.
AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  if (*this == *other) {
    return shared_from_base<AbstractBase>();
  }
  auto value_self = GetValueTrack();
  MS_EXCEPTION_IF_NULL(value_self);
  ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
  TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack());
  if (res_value == value_self) {
    return shared_from_base<AbstractBase>();
  }
  return std::make_shared<AbstractScalar>(res_value, res_type);
}

// Deep-clones the wrapped Type; returns nullptr when the value track is not a Type.
AbstractBasePtr AbstractType::Clone() const {
  ValuePtr value_self = GetValueTrack();
  if (value_self == nullptr || !value_self->isa<Type>()) {
    return nullptr;
  }
  TypePtr type_self = value_self->cast<TypePtr>();
  return std::make_shared<AbstractType>(type_self->Clone());
}

// Two AbstractTypes are equal when they wrap equal Type values.
bool AbstractType::operator==(const AbstractBase &other) const {
  if (tid() != other.tid()) {
    return false;
  }
  // Have to compare TypePtr with value;
  ValuePtr value_self = GetValueTrack();
  ValuePtr value_other = other.GetValueTrack();
  if (value_self == nullptr || value_other == nullptr) {
    MS_LOG(EXCEPTION) << "AbstractType value should not be nullptr. this: " << this->ToString()
                      << ", other: " << other.ToString();
  }
  if (!value_self->isa<Type>() || !value_other->isa<Type>()) {
    return false;
  }
  TypePtr type_self = value_self->cast<TypePtr>();
  TypePtr type_other = value_other->cast<TypePtr>();
  bool value_equal = *type_self == *type_other;
  return value_equal;
}

std::string AbstractType::ToString() const {
  std::ostringstream buffer;
  ValuePtr value_self = GetValueTrack();
  if (value_self == nullptr) {
    buffer << "AbstractType value: nullptr";
    return buffer.str();
  }
  if (!value_self->isa<Type>()) {
    buffer << type_name() << "(Value: nullptr)";
    return buffer.str();
  }
  TypePtr type_self = value_self->cast<TypePtr>();
  MS_EXCEPTION_IF_NULL(type_self);
  buffer << type_name() << "("
         << "Value: " << type_self->ToString() << ")";
  return buffer.str();
}

std::string AbstractError::ToString() const {
  std::ostringstream buffer;
  auto value_track = GetValueTrack();
  MS_EXCEPTION_IF_NULL(value_track);
  MS_EXCEPTION_IF_NULL(node_);
  buffer << type_name() << "("
         << "Value: " << value_track->ToString() << ", Node: " << node_->DebugString() << ")";
  return buffer.str();
}

// Dispatches to the AbstractFunction-typed Join overload; throws if `other` is not a function.
AbstractBasePtr AbstractFunction::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto other_func = dyn_cast<AbstractFunction>(other);
  if (other_func == nullptr) {
    MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
  }
  return Join(other_func);
}

// Dispatches to the AbstractFunction-typed equality overload.
bool AbstractFunction::operator==(const AbstractBase &other) const {
  if (!other.isa<AbstractFunction>()) {
    return false;
  }
  const auto &other_func = static_cast<const AbstractFunction &>(other);
  bool value_equal = (*this == other_func);
  return value_equal;
}

// Bounds-checked element access; throws on an out-of-range index.
const AbstractBasePtr AbstractSequeue::operator[](const std::size_t &dim) const {
  if (dim >= size()) {
    MS_LOG(EXCEPTION) << "Index [" << dim << "] Out of the size [" << size() << "] of the list.";
  }
  return elements_[dim];
}

std::string AbstractSequeue::ToString() const {
  std::ostringstream buffer;
  std::size_t i = 0;
  for (const auto &ele : elements_) {
    MS_EXCEPTION_IF_NULL(ele);
    buffer << "element[" << i << "]: " << ele->ToString() << ",";
    i++;
  }
  return buffer.str();
}

// Types of all elements, in order.
TypePtrList AbstractSequeue::ElementsType() const {
  TypePtrList element_type_list;
  element_type_list.reserve(elements_.size());
  for (const auto &ele : elements_) {
    MS_EXCEPTION_IF_NULL(ele);
    TypePtr element_type = ele->BuildType();
    element_type_list.push_back(element_type);
  }
  return element_type_list;
}

// Shapes of all elements, in order.
BaseShapePtrList AbstractSequeue::ElementsShape() const {
  BaseShapePtrList element_shape_list;
  element_shape_list.reserve(elements_.size());
  for (const auto &ele : elements_) {
    MS_EXCEPTION_IF_NULL(ele);
    BaseShapePtr element_shape = ele->BuildShape();
    element_shape_list.push_back(element_shape);
  }
  return element_shape_list;
}

// Clones of all elements, in order.
AbstractBasePtrList AbstractSequeue::ElementsClone() const {
  AbstractBasePtrList ele_list;
  ele_list.reserve(elements_.size());
  for (const auto &ele : elements_) {
    MS_EXCEPTION_IF_NULL(ele);
    AbstractBasePtr clone = ele->Clone();
    ele_list.push_back(clone);
  }
  return ele_list;
}

// Broadened copies of all elements, in order.
AbstractBasePtrList AbstractSequeue::ElementsBroaden() const {
  AbstractBasePtrList ele_list;
  ele_list.reserve(elements_.size());
  for (const auto &ele : elements_) {
    MS_EXCEPTION_IF_NULL(ele);
    AbstractBasePtr broadend = ele->Broaden();
    ele_list.push_back(broadend);
  }
  return ele_list;
}

// Builds a T (ValueTuple / ValueList) from the element values, or kAnyValue as soon as any
// element value is an AnyValue.
template <typename T>
ValuePtr AbstractSequeue::ElementsBuildValue() const {
  std::vector<ValuePtr> element_value_list;
  element_value_list.reserve(elements_.size());
  for (const auto &ele : elements_) {
    MS_EXCEPTION_IF_NULL(ele);
    ValuePtr element_value = ele->BuildValue();
    MS_EXCEPTION_IF_NULL(element_value);
    if (element_value->isa<AnyValue>()) {
      return kAnyValue;
    }
    element_value_list.push_back(element_value);
  }
  return std::make_shared<T>(element_value_list);
}
template ValuePtr AbstractSequeue::ElementsBuildValue<ValueTuple>() const;
template ValuePtr AbstractSequeue::ElementsBuildValue<ValueList>() const;

// Element-wise join with another sequence of the same kind T. Returns this object itself when
// the join yields exactly the same element pointers; otherwise a new T holding the joined list.
template <typename T>
AbstractBasePtr AbstractSequeue::ElementsJoin(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto other_sequeue = dyn_cast<T>(other);
  if (other_sequeue == nullptr) {
    MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
  }
  auto joined_list = AbstractJoin(elements_, other_sequeue->elements_);
  // Vector equality compares sizes first, then each shared_ptr by pointer identity, so a
  // size mismatch can never index past the end of either list.
  if (joined_list == elements_) {
    return shared_from_base<AbstractBase>();
  }
  return std::make_shared<T>(joined_list);
}
template AbstractBasePtr AbstractSequeue::ElementsJoin<AbstractList>(const AbstractBasePtr &);
template AbstractBasePtr AbstractSequeue::ElementsJoin<AbstractTuple>(const AbstractBasePtr &);

std::size_t AbstractSequeue::hash() const {
  std::size_t hash_sum = hash_combine(tid(), std::hash<size_t>{}(elements_.size()));
  // Hashing all elements is costly, so only take at most 4 elements into account based on
  // some experiments.
  for (size_t i = 0; (i < elements_.size()) && (i < 4); i++) {
    MS_EXCEPTION_IF_NULL(elements_[i]);
    hash_sum = hash_combine(hash_sum, elements_[i]->hash());
  }
  return hash_sum;
}

// Element-wise structural equality.
bool AbstractTuple::operator==(const AbstractTuple &other) const {
  if (&other == this) {
    return true;
  }
  if (elements_.size() != other.elements_.size()) {
    return false;
  }
  for (size_t i = 0; i < elements_.size(); i++) {
    if (!(*(elements_[i]) == *(other.elements_[i]))) {
      return false;
    }
  }
  return true;
}

bool AbstractTuple::operator==(const AbstractBase &other) const {
  if (&other == this) {
    return true;
  }
  if (other.isa<AbstractTuple>()) {
    auto other_tuple = static_cast<const AbstractTuple *>(&other);
    return *this == *other_tuple;
  }
  return false;
}

// Element-wise structural equality.
bool AbstractList::operator==(const AbstractList &other) const {
  if (&other == this) {
    return true;
  }
  if (elements_.size() != other.elements_.size()) {
    return false;
  }
  for (size_t i = 0; i < elements_.size(); i++) {
    if (!(*(elements_[i]) == *(other.elements_[i]))) {
      return false;
    }
  }
  return true;
}

bool AbstractList::operator==(const AbstractBase &other) const {
  if (&other == this) {
    return true;
  }
  if (other.isa<AbstractList>()) {
    auto other_list = static_cast<const AbstractList *>(&other);
    return *this == *other_list;
  }
  return false;
}

TypePtr AbstractSlice::BuildType() const {
  MS_EXCEPTION_IF_NULL(start_);
  MS_EXCEPTION_IF_NULL(stop_);
  MS_EXCEPTION_IF_NULL(step_);
  TypePtr start = start_->BuildType();
  TypePtr stop = stop_->BuildType();
  TypePtr step = step_->BuildType();
  return std::make_shared<Slice>(start, stop, step);
}

bool AbstractSlice::operator==(const AbstractSlice &other) const {
  if (&other == this) {
    return true;
  }
  MS_EXCEPTION_IF_NULL(start_);
  MS_EXCEPTION_IF_NULL(stop_);
  MS_EXCEPTION_IF_NULL(step_);
  MS_EXCEPTION_IF_NULL(other.start_);
  MS_EXCEPTION_IF_NULL(other.stop_);
  MS_EXCEPTION_IF_NULL(other.step_);
  return (*start_ == *other.start_ && *stop_ == *other.stop_ && *step_ == *other.step_);
}

bool AbstractSlice::operator==(const AbstractBase &other) const {
  if (&other == this) {
    return true;
  }
  if (!other.isa<AbstractSlice>()) {
    return false;
  }
  auto other_slice = static_cast<const AbstractSlice *>(&other);
  return *this == *other_slice;
}

AbstractBasePtr AbstractSlice::Clone() const {
  MS_EXCEPTION_IF_NULL(start_);
  MS_EXCEPTION_IF_NULL(stop_);
  MS_EXCEPTION_IF_NULL(step_);
  AbstractBasePtr start = start_->Clone();
  AbstractBasePtr stop = stop_->Clone();
  AbstractBasePtr step = step_->Clone();
  return std::make_shared<AbstractSlice>(start, stop, step);
}

AbstractBasePtr AbstractSlice::Broaden() const {
  MS_EXCEPTION_IF_NULL(start_);
  MS_EXCEPTION_IF_NULL(stop_);
  MS_EXCEPTION_IF_NULL(step_);
  AbstractBasePtr start = start_->Broaden();
  AbstractBasePtr stop = stop_->Broaden();
  AbstractBasePtr step = step_->Broaden();
  return std::make_shared<AbstractSlice>(start, stop, step);
}

std::string AbstractSlice::ToString() const {
  std::ostringstream buffer;
  buffer << type_name() << "[";
  MS_EXCEPTION_IF_NULL(start_);
  buffer << start_->ToString() << " : ";
  MS_EXCEPTION_IF_NULL(stop_);
  buffer << stop_->ToString() << " : ";
  MS_EXCEPTION_IF_NULL(step_);
  buffer << step_->ToString();
  buffer << "]";
  return buffer.str();
}

// Builds a ValueSlice, or kAnyValue if any of start/stop/step is an AnyValue.
ValuePtr AbstractSlice::RealBuildValue() const {
  MS_EXCEPTION_IF_NULL(start_);
  MS_EXCEPTION_IF_NULL(stop_);
  MS_EXCEPTION_IF_NULL(step_);
  ValuePtr start = start_->BuildValue();
  ValuePtr stop = stop_->BuildValue();
  ValuePtr step = step_->BuildValue();
  MS_EXCEPTION_IF_NULL(start);
  MS_EXCEPTION_IF_NULL(stop);
  MS_EXCEPTION_IF_NULL(step);
  if (start->isa<AnyValue>() || stop->isa<AnyValue>() || step->isa<AnyValue>()) {
    return kAnyValue;
  }
  return std::make_shared<ValueSlice>(start, stop, step);
}

std::size_t AbstractSlice::hash() const {
  MS_EXCEPTION_IF_NULL(start_);
  MS_EXCEPTION_IF_NULL(stop_);
  MS_EXCEPTION_IF_NULL(step_);
  return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()});
}

TypePtr AbstractTensor::BuildType() const {
  MS_EXCEPTION_IF_NULL(element_);
  TypePtr element_type = element_->BuildType();
  return std::make_shared<TensorType>(element_type);
}

BaseShapePtr AbstractTensor::BuildShape() const {
  auto shape = GetShapeTrack();
  // Guard from using set_shape(nullptr)
  if (shape == nullptr) {
    return kNoShape;
  }
  return shape;
}

// Joins element abstracts and shapes of two tensors; throws if `other` is not a tensor.
AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto other_tensor = dyn_cast<AbstractTensor>(other);
  if (other_tensor == nullptr) {
    MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
  }
  MS_EXCEPTION_IF_NULL(element_);
  auto element = element_->Join(other_tensor->element_);
  auto shape = ShapeJoin(this->shape(), other_tensor->shape());
  return std::make_shared<AbstractTensor>(element, shape);
}

// Tensors are equal when element abstracts and shapes are equal and the value tracks are the
// same object (pointer identity), or both value tracks are AnyValue.
bool AbstractTensor::operator==(const AbstractTensor &other) const {
  if (&other == this) {
    return true;
  }

  auto v1 = GetValueTrack();
  auto v2 = other.GetValueTrack();
  if (v1 == nullptr || v2 == nullptr) {
    MS_LOG(EXCEPTION) << "The value of AbstractTensor is nullptr";
  }

  bool is_value_equal = (v1 == v2);
  if (v1->isa<AnyValue>() && v2->isa<AnyValue>()) {
    is_value_equal = true;
  }
  MS_EXCEPTION_IF_NULL(element_);
  MS_EXCEPTION_IF_NULL(other.element_);
  return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal;
}

bool AbstractTensor::operator==(const AbstractBase &other) const {
  if (&other == this) {
    return true;
  }

  if (other.isa<AbstractTensor>()) {
    auto other_tensor = static_cast<const AbstractTensor *>(&other);
    return *this == *other_tensor;
  } else {
    return false;
  }
}

AbstractBasePtr AbstractTensor::Clone() const {
  MS_EXCEPTION_IF_NULL(element_);
  auto clone = std::make_shared<AbstractTensor>(element_->Clone());
  ShapePtr shp = shape();
  clone->set_shape(shp->Clone());
  clone->set_value(GetValueTrack());
  return clone;
}

// Broadens the element and value but keeps a copy of the current shape.
AbstractBasePtr AbstractTensor::Broaden() const {
  MS_EXCEPTION_IF_NULL(element_);
  auto broaden = std::make_shared<AbstractTensor>(element_->Broaden());
  auto shp = shape();
  broaden->set_shape(shp->Clone());
  broaden->set_value(kAnyValue);
  return broaden;
}

// Like Broaden(), but the copied shape is broadened as well.
AbstractBasePtr AbstractTensor::BroadenWithShape() const {
  MS_EXCEPTION_IF_NULL(element_);
  auto broaden = std::make_shared<AbstractTensor>(element_->Broaden());
  auto shp = shape()->Clone();
  shp->Broaden();
  broaden->set_shape(shp);
  broaden->set_value(kAnyValue);
  return broaden;
}

// Returns the shape track as a Shape; throws if it is missing or of another kind.
ShapePtr AbstractTensor::shape() const {
  auto shp = dyn_cast<Shape>(GetShapeTrack());
  if (shp == nullptr) {
    MS_LOG(EXCEPTION) << "Tensor should have a shape.";
  }
  return shp;
}

std::string AbstractTensor::ToString() const {
  std::ostringstream buffer;
  BaseShapePtr shape_track = GetShapeTrack();
  MS_EXCEPTION_IF_NULL(shape_track);
  MS_EXCEPTION_IF_NULL(element_);
  auto value_track = GetValueTrack();
  MS_EXCEPTION_IF_NULL(value_track);
  buffer << type_name() << "("
         << "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
         << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")";
  return buffer.str();
}

TypePtr AbstractDictionary::BuildType() const {
  std::vector<std::pair<std::string, TypePtr>> key_values;
  key_values.reserve(key_values_.size());
  for (const auto &item : key_values_) {
    MS_EXCEPTION_IF_NULL(item.second);
    TypePtr type = item.second->BuildType();
    key_values.emplace_back(item.first, type);
  }
  return std::make_shared<Dictionary>(key_values);
}

// Ordered comparison: same keys in the same positions with structurally equal values.
bool AbstractDictionary::operator==(const AbstractDictionary &other) const {
  if (key_values_.size() != other.key_values_.size()) {
    return false;
  }

  for (size_t index = 0; index < key_values_.size(); index++) {
    if (key_values_[index].first != other.key_values_[index].first) {
      return false;
    }
    MS_EXCEPTION_IF_NULL(key_values_[index].second);
    MS_EXCEPTION_IF_NULL(other.key_values_[index].second);
    if (!(*key_values_[index].second == *other.key_values_[index].second)) {
      return false;
    }
  }
  return true;
}

bool AbstractDictionary::operator==(const AbstractBase &other) const {
  if (&other == this) {
    return true;
  }
  if (other.isa<AbstractDictionary>()) {
    auto other_class = static_cast<const AbstractDictionary *>(&other);
    return *this == *other_class;
  }
  return false;
}

AbstractBasePtr AbstractDictionary::Clone() const {
  std::vector<AbstractAttribute> kv;
  kv.reserve(key_values_.size());
  (void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv),
                       [](const AbstractAttribute &item) {
                         MS_EXCEPTION_IF_NULL(item.second);
                         return std::make_pair(item.first, item.second->Clone());
                       });
  return std::make_shared<AbstractDictionary>(kv);
}

AbstractBasePtr AbstractDictionary::Broaden() const {
  std::vector<AbstractAttribute> kv;
  kv.reserve(key_values_.size());
  (void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv),
                       [](const AbstractAttribute &item) {
                         MS_EXCEPTION_IF_NULL(item.second);
                         return std::make_pair(item.first, item.second->Broaden());
                       });
  return std::make_shared<AbstractDictionary>(kv);
}

std::string AbstractDictionary::ToString() const {
  std::ostringstream buffer;
  buffer << type_name() << "{ ";
  for (const auto &kv : key_values_) {
    MS_EXCEPTION_IF_NULL(kv.second);
    buffer << "(" << kv.first << ": " << kv.second->ToString() << ") ";
  }
  buffer << "}";
  return buffer.str();
}

// Hash over the type id and every (key, value) pair.
std::size_t AbstractDictionary::hash() const {
  // std::accumulate's accumulator has the type of its init argument; seed with an explicit
  // std::size_t so intermediate hashes are not truncated when tid() returns a narrower type.
  std::size_t hash_sum = std::accumulate(key_values_.begin(), key_values_.end(), static_cast<std::size_t>(tid()),
                                         [](std::size_t acc, const AbstractAttribute &item) {
                                           acc = hash_combine(acc, std::hash<std::string>()(item.first));
                                           MS_EXCEPTION_IF_NULL(item.second);
                                           acc = hash_combine(acc, item.second->hash());
                                           return acc;
                                         });
  return hash_sum;
}

// Builds a ValueDictionary, or kAnyValue as soon as any entry's value is an AnyValue.
ValuePtr AbstractDictionary::RealBuildValue() const {
  std::vector<std::pair<std::string, ValuePtr>> key_values;
  key_values.reserve(key_values_.size());
  for (const auto &item : key_values_) {
    MS_EXCEPTION_IF_NULL(item.second);
    auto element_value = item.second->BuildValue();
    MS_EXCEPTION_IF_NULL(element_value);
    if (element_value->isa<AnyValue>()) {
      return kAnyValue;
    }
    key_values.emplace_back(item.first, element_value);
  }
  return std::make_shared<ValueDictionary>(key_values);
}

TypePtr AbstractClass::BuildType() const {
  ClassAttrVector attributes_type;
  for (const auto &attr : attributes_) {
    MS_EXCEPTION_IF_NULL(attr.second);
    TypePtr type = attr.second->BuildType();
    std::pair<std::string, TypePtr> elem(attr.first, type);
    attributes_type.push_back(elem);
  }

  return std::make_shared<Class>(tag_, attributes_type, methods_);
}

// Equal when tags match, attributes are equal in order, and both have the same methods.
bool AbstractClass::operator==(const AbstractClass &other) const {
  if (!(tag_ == other.tag_)) {
    return false;
  }
  if (attributes_.size() != other.attributes_.size()) {
    return false;
  }
  for (size_t i = 0; i < attributes_.size(); i++) {
    MS_EXCEPTION_IF_NULL(attributes_[i].second);
    MS_EXCEPTION_IF_NULL(other.attributes_[i].second);
    if (!(*attributes_[i].second == *other.attributes_[i].second)) {
      MS_LOG(DEBUG) << "attr " << attributes_[i].first << " not equal, arg1:" << attributes_[i].second->ToString()
                    << " arg2:" << other.attributes_[i].second->ToString();
      return false;
    }
  }
  // method compare;
  if (methods_.size() != other.methods_.size()) {
    return false;
  }
  for (const auto &iter : methods_) {
    auto iter_other = other.methods_.find(iter.first);
    if (iter_other == other.methods_.end()) {
      return false;
    }
    if (!(*iter.second == *iter_other->second)) {
      return false;
    }
  }
  return true;
}

bool AbstractClass::operator==(const AbstractBase &other) const {
  if (other.isa<AbstractClass>()) {
    auto other_class = static_cast<const AbstractClass *>(&other);
    return *this == *other_class;
  }
  return false;
}

// Returns the attribute named `name`, or nullptr if there is none.
AbstractBasePtr AbstractClass::GetAttribute(const std::string &name) {
  auto it = std::find_if(attributes_.begin(), attributes_.end(),
                         [&name](const AbstractAttribute &pair) -> bool { return pair.first == name; });
  if (it != attributes_.end()) {
    return it->second;
  }
  return nullptr;
}

// Returns the method named `name`, or kAnyValue if there is none.
ValuePtr AbstractClass::GetMethod(const std::string &name) {
  auto method_pair = methods_.find(name);
  if (method_pair != methods_.end()) {
    return method_pair->second;
  }
  return kAnyValue;
}

AbstractBasePtr AbstractClass::Clone() const {
  std::vector<AbstractAttribute> attributes_clone;
  attributes_clone.reserve(attributes_.size());
  for (const auto &attr : attributes_) {
    MS_EXCEPTION_IF_NULL(attr.second);
    AbstractBasePtr clone = attr.second->Clone();
    AbstractAttribute elem(attr.first, clone);
    attributes_clone.push_back(elem);
  }
  return std::make_shared<AbstractClass>(tag_, attributes_clone, methods_);
}

AbstractBasePtr AbstractClass::Broaden() const {
  std::vector<AbstractAttribute> attributes_clone;
  attributes_clone.reserve(attributes_.size());
  for (const auto &attr : attributes_) {
    MS_EXCEPTION_IF_NULL(attr.second);
    AbstractBasePtr clone = attr.second->Broaden();
    AbstractAttribute elem(attr.first, clone);
    attributes_clone.push_back(elem);
  }
  return std::make_shared<AbstractClass>(tag_, attributes_clone, methods_);
}

std::string AbstractClass::ToString() const {
  std::ostringstream buffer;
  buffer << type_name() << "(tag: " << tag_ << ") attrs:(";
  bool append_comma = false;
  for (const auto &attr : attributes_) {
    if (append_comma) {
      buffer << ", ";
    } else {
      append_comma = true;
    }
    MS_EXCEPTION_IF_NULL(attr.second);
    buffer << attr.first << ":" << attr.second->ToString();
  }
  buffer << ") method:(";
  append_comma = false;
  for (const auto &iter : methods_) {
    if (append_comma) {
      buffer << ", ";
    } else {
      append_comma = true;
    }
    MS_EXCEPTION_IF_NULL(iter.second);
    buffer << iter.first << ":" << iter.second->ToString();
  }
  buffer << ")";
  return buffer.str();
}

std::size_t AbstractClass::hash() const {
  std::size_t hash_sum = std::accumulate(attributes_.begin(), attributes_.end(), hash_combine(tid(), tag_.hash()),
                                         [](std::size_t acc, const AbstractAttribute &item) {
                                           MS_EXCEPTION_IF_NULL(item.second);
                                           return hash_combine(acc, item.second->hash());
                                         });

  return hash_sum;
}

// Builds the Class type populated with attribute values, or kAnyValue as soon as any attribute
// value is an AnyValue.
ValuePtr AbstractClass::RealBuildValue() const {
  auto cls = BuildType()->cast<ClassPtr>();
  MS_EXCEPTION_IF_NULL(cls);
  std::unordered_map<std::string, ValuePtr> attributes_value_map;
  for (const auto &attr : attributes_) {
    MS_EXCEPTION_IF_NULL(attr.second);
    ValuePtr attr_value = attr.second->BuildValue();
    MS_EXCEPTION_IF_NULL(attr_value);
    if (attr_value->isa<AnyValue>()) {
      return kAnyValue;
    }
    attributes_value_map[attr.first] = attr_value;
  }
  cls->set_value(attributes_value_map);
  return cls;
}

TypePtr AbstractJTagged::BuildType() const {
  MS_EXCEPTION_IF_NULL(element_);
  TypePtr subtype = element_->BuildType();
  return std::make_shared<JTagged>(subtype);
}

// Joins the wrapped elements; throws if `other` is not an AbstractJTagged.
AbstractBasePtr AbstractJTagged::Join(const AbstractBasePtr &other) {
  MS_EXCEPTION_IF_NULL(other);
  auto other_jtagged = dyn_cast<AbstractJTagged>(other);
  if (other_jtagged == nullptr) {
    MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
  }
  MS_EXCEPTION_IF_NULL(element_);
  auto joined_elem = element_->Join(other_jtagged->element_);
  return std::make_shared<AbstractJTagged>(joined_elem);
}

bool AbstractJTagged::operator==(const AbstractJTagged &other) const {
  MS_EXCEPTION_IF_NULL(element_);
  MS_EXCEPTION_IF_NULL(other.element_);
  return (*element_ == *other.element_);
}

bool AbstractJTagged::operator==(const AbstractBase &other) const {
  if (other.isa<AbstractJTagged>()) {
    auto other_jtagged = static_cast<const AbstractJTagged *>(&other);
    return *this == *other_jtagged;
  }
  return false;
}

std::string AbstractJTagged::ToString() const {
  std::ostringstream buffer;
  MS_EXCEPTION_IF_NULL(element_);
  buffer << type_name() << "("
         << "element: " << element_->ToString() << ")";
  return buffer.str();
}

TypePtr AbstractRef::BuildType() const {
  MS_EXCEPTION_IF_NULL(ref_);
  MS_EXCEPTION_IF_NULL(ref_origin_);
  TypePtr subtype = ref_->BuildType();
  TypePtr subtype_origin = ref_origin_->BuildType();
  return std::make_shared<RefType>(subtype, subtype_origin);
}

// Equal when the referenced value, the key and the origin value are all equal.
bool AbstractRef::operator==(const AbstractRef &other) const {
  MS_EXCEPTION_IF_NULL(ref_);
  MS_EXCEPTION_IF_NULL(ref_key_);
  MS_EXCEPTION_IF_NULL(ref_origin_);
  MS_EXCEPTION_IF_NULL(other.ref_);
  MS_EXCEPTION_IF_NULL(other.ref_key_);
  MS_EXCEPTION_IF_NULL(other.ref_origin_);
  return (*ref_ == *other.ref_) && (*ref_key_ == *other.ref_key_) && (*ref_origin_ == *other.ref_origin_);
}

bool AbstractRef::operator==(const AbstractBase &other) const {
  if (other.isa<AbstractRef>()) {
    auto other_conf = static_cast<const AbstractRef *>(&other);
    return *this == *other_conf;
  }
  return false;
}

std::string AbstractRef::ToString() const {
  std::ostringstream buffer;
  MS_EXCEPTION_IF_NULL(ref_key_);
  MS_EXCEPTION_IF_NULL(ref_);
  MS_EXCEPTION_IF_NULL(ref_origin_);
  // Fields are comma-separated; previously they were concatenated with no separator.
  buffer << type_name() << "("
         << "key: " << ref_key_->ToString() << ", ref_value: " << ref_->ToString()
         << ", origin_value: " << ref_origin_->ToString();
  auto value = GetValueTrack();
  if (value) {
    buffer << ", value: " << value->ToString();
  }
  buffer << ")";
  return buffer.str();
}

bool AbstractNone::operator==(const AbstractNone &) const { return true; }

bool AbstractNone::operator==(const AbstractBase &other) const {
  if (other.isa<AbstractNone>()) {
    auto other_none = static_cast<const AbstractNone *>(&other);
    return *this == *other_none;
  }
  return false;
}

std::string AbstractNone::ToString() const {
  std::ostringstream buffer;
  buffer << type_name() << "(Value: None)";
  return buffer.str();
}

ValuePtr AbstractNone::RealBuildValue() const { return kNone; }

// Equal when both value tracks are null, both are AnyValue, or both are equal RefKeys.
bool AbstractRefKey::operator==(const AbstractRefKey &other) const {
  ValuePtr value_self = GetValueTrack();
  ValuePtr value_other = other.GetValueTrack();
  if (value_self != nullptr && value_other != nullptr) {
    if (value_self->isa<AnyValue>() && value_other->isa<AnyValue>()) {
      return true;
    }
    if (!value_self->isa<RefKey>() || !value_other->isa<RefKey>()) {
      return false;
    }
    RefKeyPtr type_self = value_self->cast<RefKeyPtr>();
    RefKeyPtr type_other = value_other->cast<RefKeyPtr>();
    return *type_self == *type_other;
  } else if (value_self != nullptr || value_other != nullptr) {
    return false;
  }
  return true;
}

bool AbstractRefKey::operator==(const AbstractBase &other) const {
  if (other.isa<AbstractRefKey>()) {
    auto other_confkey = static_cast<const AbstractRefKey *>(&other);
    return *this == *other_confkey;
  } else {
    return false;
  }
}

std::string AbstractRefKey::ToString() const {
  std::ostringstream buffer;
  buffer << type_name();
  auto value = GetValueTrack();
  if (value) {
    buffer << "(value: " << value->ToString() << ")";
  }
  return buffer.str();
}

bool AbstractNull::operator==(const AbstractNull &) const { return true; }

bool AbstractNull::operator==(const AbstractBase &other) const {
  if (&other == this) {
    return true;
  }
  if (other.isa<AbstractNull>()) {
    auto other_none = static_cast<const AbstractNull *>(&other);
    return *this == *other_none;
  } else {
    return false;
  }
}

std::string AbstractNull::ToString() const {
  std::ostringstream buffer;
  buffer << type_name() << "(Value: Null)";
  return buffer.str();
}

bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; }

bool AbstractEllipsis::operator==(const AbstractBase &other) const {
  if (&other == this) {
    return true;
  }
  if (other.isa<AbstractEllipsis>()) {
    auto other_none = static_cast<const AbstractEllipsis *>(&other);
    return *this == *other_none;
  } else {
    return false;
  }
}

std::string AbstractEllipsis::ToString() const {
  std::ostringstream buffer;
  buffer << type_name() << "(Value: Ellipsis)";
  return buffer.str();
}

TypePtr AbstractKeywordArg::BuildType() const {
  MS_EXCEPTION_IF_NULL(arg_value_);
  TypePtr type = arg_value_->BuildType();
  return std::make_shared<Keyword>(arg_name_, type);
}

AbstractBasePtr AbstractKeywordArg::Clone() const {
  MS_EXCEPTION_IF_NULL(arg_value_);
  return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Clone());
}

AbstractBasePtr AbstractKeywordArg::Broaden() const {
  MS_EXCEPTION_IF_NULL(arg_value_);
  return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Broaden());
}
std::size_t AbstractKeywordArg::hash() const { MS_EXCEPTION_IF_NULL(arg_value_); return hash_combine({tid(), std::hash{}(arg_name_), arg_value_->hash()}); } std::string AbstractKeywordArg::ToString() const { std::ostringstream buffer; MS_EXCEPTION_IF_NULL(arg_value_); buffer << type_name() << "("; buffer << "key : " << arg_name_; buffer << "value : " << arg_value_->ToString(); buffer << ")"; return buffer.str(); } bool AbstractKeywordArg::operator==(const AbstractBase &other) const { if (&other == this) { return true; } if (other.isa()) { auto other_tuple = static_cast(&other); return *this == *other_tuple; } return false; } bool AbstractKeywordArg::operator==(const AbstractKeywordArg &other) const { if (&other == this) { return true; } MS_EXCEPTION_IF_NULL(arg_value_); MS_EXCEPTION_IF_NULL(other.arg_value_); return other.arg_name_ == arg_name_ && *other.arg_value_ == *arg_value_; } ValuePtr AbstractKeywordArg::RealBuildValue() const { MS_EXCEPTION_IF_NULL(arg_value_); ValuePtr value = arg_value_->BuildValue(); MS_EXCEPTION_IF_NULL(value); if (value->isa()) { return kAnyValue; } return std::make_shared(arg_name_, value); } std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list) { std::size_t hash_value = 0; // Hashing all elements is costly, so only take at most 4 elements into account based on // some experiments. 
for (size_t i = 0; (i < args_spec_list.size()) && (i < 4); i++) { MS_EXCEPTION_IF_NULL(args_spec_list[i]); hash_value = hash_combine(hash_value, args_spec_list[i]->hash()); } return hash_value; } bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) { if (lhs.size() != rhs.size()) { return false; } std::size_t size = lhs.size(); for (std::size_t i = 0; i < size; i++) { MS_EXCEPTION_IF_NULL(lhs[i]); MS_EXCEPTION_IF_NULL(rhs[i]); if (!(*lhs[i] == *rhs[i])) { return false; } } return true; } std::size_t AbstractBasePtrListHasher::operator()(const AbstractBasePtrList &args_spec_list) const { return AbstractBasePtrListHash(args_spec_list); } bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const { return AbstractBasePtrListDeepEqual(lhs, rhs); } } // namespace abstract } // namespace mindspore