/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "ir/value.h"

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <memory>
#include <sstream>
#include <string>
#include <utility>

#include "utils/convert_utils_base.h"

namespace mindspore {
namespace {
// Shared body of every `T::operator==(const Value &)` in this file.
// Returns false unless `other.isa<T>()` holds; otherwise forwards to T's typed operator==.
// `other` is bound by reference so the comparison never copies the object.
template <typename T>
bool EqualsSameType(const T &self, const Value &other) {
  if (!other.isa<T>()) {
    return false;
  }
  return self == static_cast<const T &>(other);
}

// Compares the values two ValuePtr handles point to.
// Raises through MS_EXCEPTION_IF_NULL instead of dereferencing a null handle.
bool PointeeEqual(const ValuePtr &lhs, const ValuePtr &rhs) {
  MS_EXCEPTION_IF_NULL(lhs);
  MS_EXCEPTION_IF_NULL(rhs);
  return *lhs == *rhs;
}
}  // namespace

// Returns the element at position `dim`; raises when `dim` is not smaller than size().
const ValuePtr ValueSequeue::operator[](const std::size_t &dim) const {
  if (dim >= size()) {
    MS_LOG(EXCEPTION) << "List index [" << dim << "] is out of range [" << size() << "].";
  }
  return elements_[dim];
}

// Removes the element at `idx`. Returns true when an element was removed, false when `idx` is out of range.
bool ValueSequeue::erase(size_t idx) {
  if (idx >= size()) {
    return false;
  }
  (void)elements_.erase(elements_.begin() + SizeToInt(idx));
  return true;
}

bool BoolImm::operator==(const Value &other) const { return EqualsSameType(*this, other); }
bool BoolImm::operator==(const BoolImm &other) const { return v_ == other.v_; }

bool Int8Imm::operator==(const Value &other) const { return EqualsSameType(*this, other); }
bool Int8Imm::operator==(const Int8Imm &other) const { return v_ == other.v_; }

bool Int16Imm::operator==(const Value &other) const { return EqualsSameType(*this, other); }
bool Int16Imm::operator==(const Int16Imm &other) const { return v_ == other.v_; }

bool Int32Imm::operator==(const Value &other) const { return EqualsSameType(*this, other); }
bool Int32Imm::operator==(const Int32Imm &other) const { return v_ == other.v_; }

bool Int64Imm::operator==(const Value &other) const { return EqualsSameType(*this, other); }
bool Int64Imm::operator==(const Int64Imm &other) const { return v_ == other.v_; }

bool UInt8Imm::operator==(const Value &other) const { return EqualsSameType(*this, other); }
bool UInt8Imm::operator==(const UInt8Imm &other) const { return v_ == other.v_; }

bool UInt16Imm::operator==(const Value &other) const { return EqualsSameType(*this, other); }
bool UInt16Imm::operator==(const UInt16Imm &other) const { return v_ == other.v_; }

bool UInt32Imm::operator==(const Value &other) const { return EqualsSameType(*this, other); }
bool UInt32Imm::operator==(const UInt32Imm &other) const { return v_ == other.v_; }

bool UInt64Imm::operator==(const Value &other) const { return EqualsSameType(*this, other); }
bool UInt64Imm::operator==(const UInt64Imm &other) const { return v_ == other.v_; }

bool FP32Imm::operator==(const Value &other) const { return EqualsSameType(*this, other); }

// Approximate equality: values closer than FLT_EPSILON are considered equal.
bool FP32Imm::operator==(const FP32Imm &other) const {
  if (std::isinf(v_) && std::isinf(other.v_)) {
    // inf - inf is NaN, so the epsilon test below cannot be used here.
    // Comparing directly keeps +inf and -inf distinct.
    return v_ == other.v_;
  }
  return std::fabs(v_ - other.v_) < FLT_EPSILON;
}

bool FP64Imm::operator==(const Value &other) const { return EqualsSameType(*this, other); }

// Approximate equality: values closer than DBL_EPSILON are considered equal.
bool FP64Imm::operator==(const FP64Imm &other) const {
  if (std::isinf(v_) && std::isinf(other.v_)) {
    // inf - inf is NaN, so the epsilon test below cannot be used here.
    // Comparing directly keeps +inf and -inf distinct.
    return v_ == other.v_;
  }
  return std::fabs(v_ - other.v_) < DBL_EPSILON;
}

bool ValueSequeue::operator==(const Value &other) const { return EqualsSameType(*this, other); }

// Two sequences are equal when they have the same length and all elements compare equal pairwise.
bool ValueSequeue::operator==(const ValueSequeue &other) const {
  if (other.elements_.size() != elements_.size()) {
    return false;
  }
  return std::equal(elements_.begin(), elements_.end(), other.elements_.begin(), PointeeEqual);
}

// Joins the ToString() of every element with ", ". Raises if an element is null.
std::string ValueSequeue::ToString() const {
  std::ostringstream buffer;
  bool begin = true;
  for (const auto &attr : elements_) {
    if (!begin) {
      buffer << ", ";
    } else {
      begin = false;
    }
    MS_EXCEPTION_IF_NULL(attr);
    buffer << attr->ToString();
  }
  return buffer.str();
}

// Joins the DumpText() of every element with ", ". Raises if an element is null.
std::string ValueSequeue::DumpText() const {
  std::ostringstream oss;
  for (size_t i = 0; i < elements_.size(); ++i) {
    MS_EXCEPTION_IF_NULL(elements_[i]);
    oss << (i > 0 ? ", " : "") << elements_[i]->DumpText();
  }
  return oss.str();
}

bool StringImm::operator==(const Value &other) const { return EqualsSameType(*this, other); }
bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; }

// Equal to `other` exactly when other.isa<AnyValue>() holds; no further state is compared.
bool AnyValue::operator==(const Value &other) const { return other.isa<AnyValue>(); }

const ValuePtr kAnyValue = std::make_shared<AnyValue>();

// Combines the type id with the hashes of start, stop and step. Raises if any of them is null.
std::size_t ValueSlice::hash() const {
  MS_EXCEPTION_IF_NULL(start_);
  MS_EXCEPTION_IF_NULL(stop_);
  MS_EXCEPTION_IF_NULL(step_);
  return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()});
}

bool ValueSlice::operator==(const Value &other) const { return EqualsSameType(*this, other); }

// Slices are equal when start, stop and step all compare equal.
// Raises if one of this slice's members is null, or if a member of `other` that has to be compared is null.
bool ValueSlice::operator==(const ValueSlice &other) const {
  MS_EXCEPTION_IF_NULL(start_);
  MS_EXCEPTION_IF_NULL(stop_);
  MS_EXCEPTION_IF_NULL(step_);
  return PointeeEqual(start_, other.start_) && PointeeEqual(stop_, other.stop_) &&
         PointeeEqual(step_, other.step_);
}

// Renders the slice as "Slice[start : stop : step]". Raises if start, stop or step is null.
std::string ValueSlice::ToString() const {
  MS_EXCEPTION_IF_NULL(start_);
  MS_EXCEPTION_IF_NULL(stop_);
  MS_EXCEPTION_IF_NULL(step_);
  std::ostringstream buffer;
  buffer << "Slice[";
  buffer << start_->ToString() << " : ";
  buffer << stop_->ToString() << " : ";
  buffer << step_->ToString();
  buffer << "]";
  return buffer.str();
}

// Combines the type id with the hashes of the key and the value. Raises if the value is null.
std::size_t KeywordArg::hash() const {
  MS_EXCEPTION_IF_NULL(value_);
  return hash_combine({tid(), std::hash<std::string>{}(key_), value_->hash()});
}

bool KeywordArg::operator==(const Value &other) const { return EqualsSameType(*this, other); }

// Keyword arguments are equal when both key and value match.
// The values are only inspected (and null-checked) when the keys already match.
bool KeywordArg::operator==(const KeywordArg &other) const {
  return other.key_ == key_ && PointeeEqual(other.value_, value_);
}

// Renders the argument as "KeywordArg[key : <key>, value : <value>]". Raises if the value is null.
std::string KeywordArg::ToString() const {
  std::ostringstream buffer;
  buffer << "KeywordArg[";
  buffer << "key : " << key_;
  MS_EXCEPTION_IF_NULL(value_);
  buffer << ", value : " << value_->ToString();
  buffer << "]";
  return buffer.str();
}

// Returns the value stored under `key` (first match in insertion order); raises when the key is absent.
const ValuePtr ValueDictionary::operator[](const std::string &key) const {
  // The lambda is only used within this call, so capturing `key` by reference is safe and avoids a string copy.
  auto it = std::find_if(key_values_.begin(), key_values_.end(),
                         [&key](const std::pair<std::string, ValuePtr> &item) { return item.first == key; });
  if (it == key_values_.end()) {
    MS_LOG(EXCEPTION) << "The key " << key << " is not in the map";
  }
  return it->second;
}

bool ValueDictionary::operator==(const Value &other) const { return EqualsSameType(*this, other); }

// Dictionaries are equal when they hold the same keys in the same order with equal values.
bool ValueDictionary::operator==(const ValueDictionary &other) const {
  if (key_values_.size() != other.key_values_.size()) {
    return false;
  }
  for (size_t index = 0; index < key_values_.size(); ++index) {
    if (key_values_[index].first != other.key_values_[index].first) {
      return false;
    }
    if (!PointeeEqual(key_values_[index].second, other.key_values_[index].second)) {
      return false;
    }
  }
  return true;
}
}  // namespace mindspore