/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_UTILS_COPLEX_H_ #define MINDSPORE_CCSRC_UTILS_COPLEX_H_ #include #include #ifdef ENABLE_GPU #include #endif #include "base/float16.h" #if defined(__CUDACC__) #define HOST_DEVICE __host__ __device__ #else #define HOST_DEVICE #endif namespace mindspore { namespace utils { // Implement Complex for mindspore, inspired by std::complex. template struct alignas(sizeof(T) * 2) Complex { Complex() = default; ~Complex() = default; Complex(const Complex &other) noexcept = default; Complex(Complex &&other) noexcept = default; Complex &operator=(const Complex &other) noexcept = default; Complex &operator=(Complex &&other) noexcept = default; HOST_DEVICE inline constexpr Complex(const T &real, const T &imag = T()) : real_(real), imag_(imag) {} template inline explicit constexpr Complex(const std::complex &other) : Complex(other.real(), other.imag()) {} template inline explicit constexpr operator std::complex() const { return std::complex(std::complex(real(), imag())); } HOST_DEVICE inline explicit constexpr Complex(const float16 &real) : real_(static_cast(real)), imag_(T()) {} #if defined(__CUDACC__) template HOST_DEVICE inline explicit Complex(const thrust::complex &other) : real_(other.real()), imag_(other.imag()) {} template HOST_DEVICE inline HOST_DEVICE explicit operator thrust::complex() const { return static_cast>(thrust::complex(real(), imag())); } #endif template HOST_DEVICE explicit Complex(const std::enable_if_t::value, Complex> &other) : real_(other.real()), imag_(other.imag()) {} template HOST_DEVICE explicit Complex(const std::enable_if_t::value, Complex> &other) : real_(other.real()), imag_(other.imag()) {} HOST_DEVICE inline explicit operator bool() const { return static_cast(real_) || static_cast(imag_); } HOST_DEVICE inline explicit operator signed char() const { return static_cast(real_); } HOST_DEVICE inline explicit operator unsigned char() const { return static_cast(real_); } HOST_DEVICE inline explicit operator double() const { return static_cast(real_); } HOST_DEVICE inline explicit operator float() const { return static_cast(real_); } HOST_DEVICE inline explicit operator int16_t() const { return static_cast(real_); } HOST_DEVICE inline explicit operator uint16_t() const { return static_cast(real_); } HOST_DEVICE inline explicit operator int32_t() const { return static_cast(real_); } HOST_DEVICE inline explicit operator uint32_t() const { return static_cast(real_); } HOST_DEVICE inline explicit operator int64_t() const { return static_cast(real_); } HOST_DEVICE inline explicit operator uint64_t() const { return static_cast(real_); } HOST_DEVICE inline explicit operator float16() const { return static_cast(real_); } HOST_DEVICE inline constexpr Complex &operator=(const T &real) { real_ = real; imag_ = T(); return *this; } HOST_DEVICE inline Complex &operator+=(const T &real) { real_ += real; return *this; } HOST_DEVICE inline Complex &operator-=(const T &real) { real_ -= real; return *this; } HOST_DEVICE inline Complex &operator*=(const T &real) { real_ *= real; return *this; } // Note: check division by zero before use it. HOST_DEVICE inline Complex &operator/=(const T &real) { real_ /= real; return *this; } template HOST_DEVICE inline Complex &operator=(const Complex &z) { real_ = z.real(); imag_ = z.imag(); return *this; } template HOST_DEVICE inline Complex &operator+=(const Complex &z) { real_ += z.real(); imag_ += z.imag(); return *this; } template HOST_DEVICE inline Complex &operator-=(const Complex &z) { real_ -= z.real(); imag_ -= z.imag(); return *this; } template HOST_DEVICE inline Complex &operator*=(const Complex &z); // Note: check division by zero before use it. template HOST_DEVICE inline Complex &operator/=(const Complex &z); HOST_DEVICE inline constexpr T real() const { return real_; } HOST_DEVICE inline constexpr T imag() const { return imag_; } HOST_DEVICE inline void real(T val) { real_ = val; } HOST_DEVICE inline void imag(T val) { imag_ = val; } private: T real_; T imag_; }; template template HOST_DEVICE inline Complex &Complex::operator*=(const Complex &z) { const T real = real_ * z.real() - imag_ * z.imag(); imag_ = real_ * z.imag() + imag_ * z.real(); real_ = real; return *this; } // Note: check division by zero before use it. template template HOST_DEVICE inline Complex &Complex::operator/=(const Complex &z) { T a = real_; T b = imag_; U c = z.real(); U d = z.imag(); auto denominator = c * c + d * d; real_ = (a * c + b * d) / denominator; imag_ = (b * c - a * d) / denominator; return *this; } template HOST_DEVICE inline Complex operator+(const Complex &lhs, const Complex &rhs) { Complex result = lhs; result += rhs; return result; } template HOST_DEVICE inline Complex operator+(const Complex &lhs, const T &rhs) { Complex result = lhs; result += rhs; return result; } template HOST_DEVICE inline Complex operator+(const T &lhs, const Complex &rhs) { Complex result = rhs; result += lhs; return result; } template HOST_DEVICE inline Complex operator-(const Complex &lhs, const Complex &rhs) { Complex result = lhs; result -= rhs; return result; } template HOST_DEVICE inline Complex operator-(const Complex &lhs, const T &rhs) { Complex result = lhs; result -= rhs; return result; } template HOST_DEVICE inline Complex operator-(const T &lhs, const Complex &rhs) { Complex result(lhs, -rhs.imag()); result -= rhs.real(); return result; } template HOST_DEVICE inline Complex operator*(const Complex &lhs, const Complex &rhs) { Complex result = lhs; result *= rhs; return result; } template HOST_DEVICE inline Complex operator*(const Complex &lhs, const T &rhs) { Complex result = lhs; result *= rhs; return result; } template HOST_DEVICE inline Complex operator*(const T &lhs, const Complex &rhs) { Complex result = rhs; result *= lhs; return result; } // Note: check division by zero before use it. template HOST_DEVICE inline Complex operator/(const Complex &lhs, const Complex &rhs) { Complex result = lhs; result /= rhs; return result; } // Note: check division by zero before use it. template HOST_DEVICE inline Complex operator/(const Complex &lhs, const T &rhs) { Complex result = lhs; result /= rhs; return result; } // Note: check division by zero before use it. template HOST_DEVICE inline Complex operator/(const T &lhs, const Complex &rhs) { Complex result = lhs; result /= rhs; return result; } template HOST_DEVICE inline Complex operator+(const Complex &z) { return z; } template HOST_DEVICE inline Complex operator-(const Complex &z) { return Complex(-z.real(), -z.imag()); } template HOST_DEVICE inline bool operator==(const Complex &lhs, const Complex &rhs) { return lhs.real() == rhs.real() && lhs.imag() == rhs.imag(); } template HOST_DEVICE inline bool operator==(const T &lhs, const Complex &rhs) { return lhs == rhs.real() && rhs.imag() == 0; } template HOST_DEVICE inline bool operator==(const Complex &lhs, const T &rhs) { return lhs.real() == rhs && lhs.imag() == 0; } template HOST_DEVICE inline bool operator!=(const Complex &lhs, const Complex &rhs) { return !(lhs == rhs); } template HOST_DEVICE inline bool operator!=(const T &lhs, const Complex &rhs) { return !(lhs == rhs); } template HOST_DEVICE inline bool operator!=(const Complex &lhs, const T &rhs) { return !(lhs == rhs); } template inline std::ostream &operator<<(std::ostream &os, const Complex &v) { return (os << std::noshowpos << v.real() << std::showpos << v.imag() << 'j'); } template HOST_DEVICE inline T abs(const Complex &z) { #if defined(__CUDACC__) return thrust::abs(thrust::complex(z)); #else return std::abs(std::complex(z)); #endif } } // namespace utils } // namespace mindspore template using Complex = mindspore::utils::Complex; namespace std { template class numeric_limits> : public numeric_limits {}; } // namespace std #endif // MINDSPORE_CCSRC_UTILS_COPLEX_H_