/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_IR_SCALAR_H_ #define MINDSPORE_CCSRC_IR_SCALAR_H_ namespace mindspore { /* namespace to support inference engine */ #include #include #include #include #include #include #include #include #include "ir/base.h" #include "ir/dtype.h" class Scalar : public Value { public: Scalar() = default; explicit Scalar(const TypePtr t) : Value(t) {} ~Scalar() override = default; MS_DECLARE_PARENT(Scalar, Value) virtual bool IsZero() = 0; virtual bool IsOne() = 0; abstract::AbstractBasePtr ToAbstract() override; protected: std::size_t hash_ = 0; }; using ScalarPtr = std::shared_ptr; class BoolImm : public Scalar { public: explicit BoolImm(bool b) : Scalar(kBool), v_(b) { hash_ = std::hash{}(v_); } ~BoolImm() override = default; MS_DECLARE_PARENT(BoolImm, Scalar) std::size_t hash() const override { return hash_; } bool value() const { return v_; } bool IsZero() override { return v_ == false; } bool IsOne() override { return v_ == true; } bool operator==(const Value &other) const override; bool operator==(const BoolImm &other) const; std::string ToString() const override { if (v_) { return "true"; } else { return "false"; } } std::string DumpText() const override { std::ostringstream oss; oss << "Bool(" << v_ << ")"; return oss.str(); } private: bool v_; }; using BoolImmPtr = std::shared_ptr; IMM_TRAITS(BoolImmPtr, bool) class IntergerImm : public Scalar { public: IntergerImm() = default; explicit IntergerImm(const TypePtr &t) : Scalar(t) {} ~IntergerImm() override = default; MS_DECLARE_PARENT(IntergerImm, Scalar) }; class Int8Imm : public IntergerImm { public: Int8Imm() : IntergerImm(kInt8), v_(0) {} explicit Int8Imm(int8_t v) : IntergerImm(kInt8), v_(v) { hash_ = std::hash{}(v_); } ~Int8Imm() override = default; MS_DECLARE_PARENT(Int8Imm, IntergerImm) std::size_t hash() const override { return hash_; } bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } int8_t value() const { return v_; } bool operator==(const Value &other) const override; bool operator==(const Int8Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { std::ostringstream oss; oss << "I8(" << v_ << ")"; return oss.str(); } private: int8_t v_; }; using Int8ImmPtr = std::shared_ptr; IMM_TRAITS(Int8ImmPtr, int8_t) class Int16Imm : public IntergerImm { public: Int16Imm() : IntergerImm(kInt16), v_(0) {} explicit Int16Imm(int16_t v) : IntergerImm(kInt16), v_(v) { hash_ = std::hash{}(v_); } ~Int16Imm() override = default; MS_DECLARE_PARENT(Int16Imm, IntergerImm) std::size_t hash() const override { return hash_; } bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } int16_t value() const { return v_; } bool operator==(const Value &other) const override; bool operator==(const Int16Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { std::ostringstream oss; oss << "I16(" << v_ << ")"; return oss.str(); } private: int16_t v_; }; using Int16ImmPtr = std::shared_ptr; IMM_TRAITS(Int16ImmPtr, int16_t) class Int32Imm : public IntergerImm { public: Int32Imm() : IntergerImm(kInt32), v_(0) {} explicit Int32Imm(int v) : IntergerImm(kInt32), v_(v) { hash_ = std::hash{}(v_); } ~Int32Imm() override = default; MS_DECLARE_PARENT(Int32Imm, IntergerImm) std::size_t hash() const override { return hash_; } bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } int32_t value() const { return v_; } bool operator==(const Value &other) const override; bool operator==(const Int32Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { std::ostringstream oss; oss << "I32(" << v_ << ")"; return oss.str(); } private: int32_t v_; }; using Int32ImmPtr = std::shared_ptr; IMM_TRAITS(Int32ImmPtr, int32_t) class Int64Imm : public IntergerImm { public: Int64Imm() : IntergerImm(kInt64), v_(0) {} explicit Int64Imm(int64_t v) : IntergerImm(kInt64), v_(v) { hash_ = std::hash{}(v_); } ~Int64Imm() override = default; MS_DECLARE_PARENT(Int64Imm, IntergerImm) std::size_t hash() const override { return hash_; } bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } int64_t value() const { return v_; } bool operator==(const Value &other) const override; bool operator==(const Int64Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { std::ostringstream oss; oss << "I64(" << v_ << ")"; return oss.str(); } private: int64_t v_; }; using Int64ImmPtr = std::shared_ptr; IMM_TRAITS(Int64ImmPtr, int64_t) class UInt8Imm : public IntergerImm { public: UInt8Imm() : IntergerImm(kUInt8), v_(0) {} explicit UInt8Imm(uint8_t v) : IntergerImm(kUInt8), v_(v) { hash_ = std::hash{}(v_); } ~UInt8Imm() override = default; MS_DECLARE_PARENT(UInt8Imm, IntergerImm) std::size_t hash() const override { return hash_; } bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } uint8_t value() const { return v_; } bool operator==(const Value &other) const override; bool operator==(const UInt8Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { std::ostringstream oss; oss << "U8(" << v_ << ")"; return oss.str(); } private: uint8_t v_; }; using UInt8ImmPtr = std::shared_ptr; IMM_TRAITS(UInt8ImmPtr, uint8_t); class UInt16Imm : public IntergerImm { public: UInt16Imm() : IntergerImm(kUInt16), v_(0) {} explicit UInt16Imm(uint16_t v) : IntergerImm(kUInt16), v_(v) { hash_ = std::hash{}(v_); } ~UInt16Imm() override = default; MS_DECLARE_PARENT(UInt16Imm, IntergerImm) std::size_t hash() const override { return hash_; } bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } uint16_t value() const { return v_; } bool operator==(const Value &other) const override; bool operator==(const UInt16Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { std::ostringstream oss; oss << "U16(" << v_ << ")"; return oss.str(); } private: uint16_t v_; }; using UInt16ImmPtr = std::shared_ptr; IMM_TRAITS(UInt16ImmPtr, uint16_t); class UInt32Imm : public IntergerImm { public: UInt32Imm() : IntergerImm(kUInt32), v_(0) {} explicit UInt32Imm(uint32_t v) : IntergerImm(kUInt32), v_(v) { hash_ = std::hash{}(v_); } ~UInt32Imm() override = default; MS_DECLARE_PARENT(UInt32Imm, IntergerImm) std::size_t hash() const override { return hash_; } bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } uint32_t value() const { return v_; } bool operator==(const Value &other) const override; bool operator==(const UInt32Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { std::ostringstream oss; oss << "U32(" << v_ << ")"; return oss.str(); } private: uint32_t v_; }; using UInt32ImmPtr = std::shared_ptr; IMM_TRAITS(UInt32ImmPtr, uint32_t); class UInt64Imm : public IntergerImm { public: UInt64Imm() : IntergerImm(kUInt64), v_(0) {} explicit UInt64Imm(uint64_t v) : IntergerImm(kUInt64), v_(v) { hash_ = std::hash{}(v); } ~UInt64Imm() override = default; MS_DECLARE_PARENT(UInt64Imm, IntergerImm) std::size_t hash() const override { return hash_; } bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } uint64_t value() const { return v_; } bool operator==(const Value &other) const override; bool operator==(const UInt64Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { std::ostringstream oss; oss << "U64(" << v_ << ")"; return oss.str(); } private: uint64_t v_; }; using UInt64ImmPtr = std::shared_ptr; IMM_TRAITS(UInt64ImmPtr, uint64_t); class FloatImm : public Scalar { public: FloatImm() = default; explicit FloatImm(const TypePtr &t) : Scalar(t) {} ~FloatImm() override = default; MS_DECLARE_PARENT(FloatImm, Scalar) }; using FloatImmPtr = std::shared_ptr; class FP32Imm : public FloatImm { public: FP32Imm() : FloatImm(kFloat32), v_(0.0) {} explicit FP32Imm(float v) : FloatImm(kFloat32), v_(v) { hash_ = std::hash{}(v_); } ~FP32Imm() override = default; MS_DECLARE_PARENT(FP32Imm, FloatImm) std::size_t hash() const override { return hash_; } bool IsZero() override { return fabs(v_) <= FLT_EPSILON; } bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; } float value() const { return v_; } bool operator==(const Value &other) const override; bool operator==(const FP32Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { std::ostringstream oss; oss << "F32(" << v_ << ")"; return oss.str(); } private: float v_; }; using FP32ImmPtr = std::shared_ptr; IMM_TRAITS(FP32ImmPtr, float) class FP64Imm : public FloatImm { public: FP64Imm() : FloatImm(kFloat64), v_(0.0) {} explicit FP64Imm(double v) : FloatImm(kFloat64), v_(v) { hash_ = std::hash{}(v_); } ~FP64Imm() override = default; MS_DECLARE_PARENT(FP64Imm, FloatImm) std::size_t hash() const override { return hash_; } bool IsZero() override { return fabs(v_) <= DBL_EPSILON; } bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; } double value() const { return v_; } bool operator==(const Value &other) const override; bool operator==(const FP64Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { std::ostringstream oss; oss << "F64(" << v_ << ")"; return oss.str(); } private: double v_; }; using FP64ImmPtr = std::shared_ptr; IMM_TRAITS(FP64ImmPtr, double) } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_SCALAR_H_