Merge pull request !21351 from zhouyaqiang0/complex_supporttags/v1.5.0-rc1
| @@ -109,6 +109,22 @@ REGISTER_PYBIND_DEFINE( | |||
| Float data(t[0].cast<py::int_>()); | |||
| return data; | |||
| })); | |||
| (void)py::class_<Complex, Number, std::shared_ptr<Complex>>(m_sub, "Complex") | |||
| .def(py::init()) | |||
| .def(py::init<int>(), py::arg("nbits")) | |||
| .def(py::pickle( | |||
| [](const Complex &t) { // __getstate__ | |||
| /* Return a tuple that fully encodes the state of the object */ | |||
| return py::make_tuple(py::int_(t.nbits())); | |||
| }, | |||
| [](const py::tuple &t) { // __setstate__ | |||
| if (t.size() != 1) { | |||
| throw std::runtime_error("Invalid state!"); | |||
| } | |||
| /* Create a new C++ instance */ | |||
| Complex data(t[0].cast<py::int_>()); | |||
| return data; | |||
| })); | |||
| (void)py::class_<List, Type, std::shared_ptr<List>>(m_sub, "List") | |||
| .def(py::init()) | |||
| .def(py::init<std::vector<TypePtr>>(), py::arg("elements")); | |||
| @@ -20,6 +20,7 @@ | |||
| #include <sstream> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <complex> | |||
| #include "pybind_api/api_register.h" | |||
| #include "abstract/abstract_value.h" | |||
| @@ -78,9 +79,15 @@ static TypeId GetDataType(const py::buffer_info &buf) { | |||
| case '?': | |||
| return TypeId::kNumberTypeBool; | |||
| } | |||
| } else if (buf.format.size() >= 2 && buf.format.back() == 'w') { | |||
| } else if (buf.format.size() >= 2) { | |||
| // Support np.str_ dtype, format: {x}w. {x} is a number that means the maximum length of the string items. | |||
| return TypeId::kObjectTypeString; | |||
| if (buf.format.back() == 'w') { | |||
| return TypeId::kObjectTypeString; | |||
| } else if (buf.format == "Zf") { | |||
| return TypeId::kNumberTypeComplex64; | |||
| } else if (buf.format == "Zd") { | |||
| return TypeId::kNumberTypeComplex128; | |||
| } | |||
| } | |||
| MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << ", item size " << buf.itemsize; | |||
| return TypeId::kTypeUnknown; | |||
| @@ -114,6 +121,10 @@ static std::string GetPyTypeFormat(TypeId data_type) { | |||
| return py::format_descriptor<bool>::format(); | |||
| case TypeId::kObjectTypeString: | |||
| return py::format_descriptor<uint8_t>::format(); | |||
| case TypeId::kNumberTypeComplex64: | |||
| return py::format_descriptor<std::complex<float>>::format(); | |||
| case TypeId::kNumberTypeComplex128: | |||
| return py::format_descriptor<std::complex<double>>::format(); | |||
| default: | |||
| MS_LOG(WARNING) << "Unsupported DataType " << data_type << "."; | |||
| return ""; | |||
| @@ -0,0 +1,323 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_UTILS_COPLEX_H_ | |||
| #define MINDSPORE_CCSRC_UTILS_COPLEX_H_ | |||
| #include <complex> | |||
| #include <limits> | |||
| #ifdef ENABLE_GPU | |||
| #include <thrust/complex.h> | |||
| #endif | |||
| #include "base/float16.h" | |||
| #if defined(__CUDACC__) | |||
| #define HOST_DEVICE __host__ __device__ | |||
| #else | |||
| #define HOST_DEVICE | |||
| #endif | |||
| namespace mindspore { | |||
| namespace utils { | |||
| // Implement Complex for mindspore, inspired by std::complex. | |||
| template <typename T> | |||
| struct alignas(sizeof(T) * 2) Complex { | |||
| Complex() = default; | |||
| ~Complex() = default; | |||
| Complex(const Complex<T> &other) noexcept = default; | |||
| Complex(Complex<T> &&other) noexcept = default; | |||
| Complex &operator=(const Complex<T> &other) noexcept = default; | |||
| Complex &operator=(Complex<T> &&other) noexcept = default; | |||
| HOST_DEVICE inline constexpr Complex(const T &real, const T &imag = T()) : real_(real), imag_(imag) {} | |||
| template <typename U> | |||
| inline explicit constexpr Complex(const std::complex<U> &other) : Complex(other.real(), other.imag()) {} | |||
| template <typename U> | |||
| inline explicit constexpr operator std::complex<U>() const { | |||
| return std::complex<U>(std::complex<T>(real(), imag())); | |||
| } | |||
| HOST_DEVICE inline explicit constexpr Complex(const float16 &real) : real_(static_cast<T>(real)), imag_(T()) {} | |||
| #if defined(__CUDACC__) | |||
| template <typename U> | |||
| HOST_DEVICE inline explicit Complex(const thrust::complex<U> &other) : real_(other.real()), imag_(other.imag()) {} | |||
| template <typename U> | |||
| HOST_DEVICE inline HOST_DEVICE explicit operator thrust::complex<U>() const { | |||
| return static_cast<thrust::complex<U>>(thrust::complex<T>(real(), imag())); | |||
| } | |||
| #endif | |||
| template <typename U = T> | |||
| HOST_DEVICE explicit Complex(const std::enable_if_t<std::is_same<U, float>::value, Complex<double>> &other) | |||
| : real_(other.real()), imag_(other.imag()) {} | |||
| template <typename U = T> | |||
| HOST_DEVICE explicit Complex(const std::enable_if_t<std::is_same<U, double>::value, Complex<float>> &other) | |||
| : real_(other.real()), imag_(other.imag()) {} | |||
| HOST_DEVICE inline explicit operator bool() const { return static_cast<bool>(real_) || static_cast<bool>(imag_); } | |||
| HOST_DEVICE inline explicit operator signed char() const { return static_cast<signed char>(real_); } | |||
| HOST_DEVICE inline explicit operator unsigned char() const { return static_cast<unsigned char>(real_); } | |||
| HOST_DEVICE inline explicit operator double() const { return static_cast<double>(real_); } | |||
| HOST_DEVICE inline explicit operator float() const { return static_cast<float>(real_); } | |||
| HOST_DEVICE inline explicit operator int16_t() const { return static_cast<int16_t>(real_); } | |||
| HOST_DEVICE inline explicit operator uint16_t() const { return static_cast<uint16_t>(real_); } | |||
| HOST_DEVICE inline explicit operator int32_t() const { return static_cast<int32_t>(real_); } | |||
| HOST_DEVICE inline explicit operator uint32_t() const { return static_cast<uint32_t>(real_); } | |||
| HOST_DEVICE inline explicit operator int64_t() const { return static_cast<int64_t>(real_); } | |||
| HOST_DEVICE inline explicit operator uint64_t() const { return static_cast<uint64_t>(real_); } | |||
| HOST_DEVICE inline explicit operator float16() const { return static_cast<float16>(real_); } | |||
| HOST_DEVICE inline constexpr Complex<T> &operator=(const T &real) { | |||
| real_ = real; | |||
| imag_ = T(); | |||
| return *this; | |||
| } | |||
| HOST_DEVICE inline Complex<T> &operator+=(const T &real) { | |||
| real_ += real; | |||
| return *this; | |||
| } | |||
| HOST_DEVICE inline Complex<T> &operator-=(const T &real) { | |||
| real_ -= real; | |||
| return *this; | |||
| } | |||
| HOST_DEVICE inline Complex<T> &operator*=(const T &real) { | |||
| real_ *= real; | |||
| return *this; | |||
| } | |||
| // Note: check division by zero before use it. | |||
| HOST_DEVICE inline Complex<T> &operator/=(const T &real) { | |||
| real_ /= real; | |||
| return *this; | |||
| } | |||
| template <typename U> | |||
| HOST_DEVICE inline Complex<T> &operator=(const Complex<U> &z) { | |||
| real_ = z.real(); | |||
| imag_ = z.imag(); | |||
| return *this; | |||
| } | |||
| template <typename U> | |||
| HOST_DEVICE inline Complex<T> &operator+=(const Complex<U> &z) { | |||
| real_ += z.real(); | |||
| imag_ += z.imag(); | |||
| return *this; | |||
| } | |||
| template <typename U> | |||
| HOST_DEVICE inline Complex<T> &operator-=(const Complex<U> &z) { | |||
| real_ -= z.real(); | |||
| imag_ -= z.imag(); | |||
| return *this; | |||
| } | |||
| template <typename U> | |||
| HOST_DEVICE inline Complex<T> &operator*=(const Complex<U> &z); | |||
| // Note: check division by zero before use it. | |||
| template <typename U> | |||
| HOST_DEVICE inline Complex<T> &operator/=(const Complex<U> &z); | |||
| HOST_DEVICE inline constexpr T real() const { return real_; } | |||
| HOST_DEVICE inline constexpr T imag() const { return imag_; } | |||
| HOST_DEVICE inline void real(T val) { real_ = val; } | |||
| HOST_DEVICE inline void imag(T val) { imag_ = val; } | |||
| private: | |||
| T real_; | |||
| T imag_; | |||
| }; | |||
| template <typename T> | |||
| template <typename U> | |||
| HOST_DEVICE inline Complex<T> &Complex<T>::operator*=(const Complex<U> &z) { | |||
| const T real = real_ * z.real() - imag_ * z.imag(); | |||
| imag_ = real_ * z.imag() + imag_ * z.real(); | |||
| real_ = real; | |||
| return *this; | |||
| } | |||
| // Note: check division by zero before use it. | |||
| template <typename T> | |||
| template <typename U> | |||
| HOST_DEVICE inline Complex<T> &Complex<T>::operator/=(const Complex<U> &z) { | |||
| T a = real_; | |||
| T b = imag_; | |||
| U c = z.real(); | |||
| U d = z.imag(); | |||
| auto denominator = c * c + d * d; | |||
| real_ = (a * c + b * d) / denominator; | |||
| imag_ = (b * c - a * d) / denominator; | |||
| return *this; | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline Complex<T> operator+(const Complex<T> &lhs, const Complex<T> &rhs) { | |||
| Complex<T> result = lhs; | |||
| result += rhs; | |||
| return result; | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline Complex<T> operator+(const Complex<T> &lhs, const T &rhs) { | |||
| Complex<T> result = lhs; | |||
| result += rhs; | |||
| return result; | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline Complex<T> operator+(const T &lhs, const Complex<T> &rhs) { | |||
| Complex<T> result = rhs; | |||
| result += lhs; | |||
| return result; | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline Complex<T> operator-(const Complex<T> &lhs, const Complex<T> &rhs) { | |||
| Complex<T> result = lhs; | |||
| result -= rhs; | |||
| return result; | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline Complex<T> operator-(const Complex<T> &lhs, const T &rhs) { | |||
| Complex<T> result = lhs; | |||
| result -= rhs; | |||
| return result; | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline Complex<T> operator-(const T &lhs, const Complex<T> &rhs) { | |||
| Complex<T> result(lhs, -rhs.imag()); | |||
| result -= rhs.real(); | |||
| return result; | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline Complex<T> operator*(const Complex<T> &lhs, const Complex<T> &rhs) { | |||
| Complex<T> result = lhs; | |||
| result *= rhs; | |||
| return result; | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline Complex<T> operator*(const Complex<T> &lhs, const T &rhs) { | |||
| Complex<T> result = lhs; | |||
| result *= rhs; | |||
| return result; | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline Complex<T> operator*(const T &lhs, const Complex<T> &rhs) { | |||
| Complex<T> result = rhs; | |||
| result *= lhs; | |||
| return result; | |||
| } | |||
| // Note: check division by zero before use it. | |||
| template <typename T> | |||
| HOST_DEVICE inline Complex<T> operator/(const Complex<T> &lhs, const Complex<T> &rhs) { | |||
| Complex<T> result = lhs; | |||
| result /= rhs; | |||
| return result; | |||
| } | |||
| // Note: check division by zero before use it. | |||
| template <typename T> | |||
| HOST_DEVICE inline Complex<T> operator/(const Complex<T> &lhs, const T &rhs) { | |||
| Complex<T> result = lhs; | |||
| result /= rhs; | |||
| return result; | |||
| } | |||
| // Note: check division by zero before use it. | |||
| template <typename T> | |||
| HOST_DEVICE inline Complex<T> operator/(const T &lhs, const Complex<T> &rhs) { | |||
| Complex<T> result = lhs; | |||
| result /= rhs; | |||
| return result; | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline Complex<T> operator+(const Complex<T> &z) { | |||
| return z; | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline Complex<T> operator-(const Complex<T> &z) { | |||
| return Complex<T>(-z.real(), -z.imag()); | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline bool operator==(const Complex<T> &lhs, const Complex<T> &rhs) { | |||
| return lhs.real() == rhs.real() && lhs.imag() == rhs.imag(); | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline bool operator==(const T &lhs, const Complex<T> &rhs) { | |||
| return lhs == rhs.real() && rhs.imag() == 0; | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline bool operator==(const Complex<T> &lhs, const T &rhs) { | |||
| return lhs.real() == rhs && lhs.imag() == 0; | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline bool operator!=(const Complex<T> &lhs, const Complex<T> &rhs) { | |||
| return !(lhs == rhs); | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline bool operator!=(const T &lhs, const Complex<T> &rhs) { | |||
| return !(lhs == rhs); | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline bool operator!=(const Complex<T> &lhs, const T &rhs) { | |||
| return !(lhs == rhs); | |||
| } | |||
| template <typename T> | |||
| inline std::ostream &operator<<(std::ostream &os, const Complex<T> &v) { | |||
| return (os << std::noshowpos << v.real() << std::showpos << v.imag() << 'j'); | |||
| } | |||
| template <typename T> | |||
| HOST_DEVICE inline T abs(const Complex<T> &z) { | |||
| #if defined(__CUDACC__) | |||
| return thrust::abs(thrust::complex<T>(z)); | |||
| #else | |||
| return std::abs(std::complex<T>(z)); | |||
| #endif | |||
| } | |||
| } // namespace utils | |||
| } // namespace mindspore | |||
| template <typename T> | |||
| using Complex = mindspore::utils::Complex<T>; | |||
| namespace std { | |||
| template <typename T> | |||
| class numeric_limits<mindspore::utils::Complex<T>> : public numeric_limits<T> {}; | |||
| } // namespace std | |||
| #endif // MINDSPORE_CCSRC_UTILS_COPLEX_H_ | |||
| @@ -38,7 +38,8 @@ __dtype__ = [ | |||
| "number", "tensor", | |||
| "string", "type_none", | |||
| "tensor_type", | |||
| "Type", "Int" | |||
| "Type", "Int", | |||
| "complex64", "complex128" | |||
| ] | |||
| __method__ = [ | |||
| @@ -77,6 +78,8 @@ float32 = typing.Float(32) | |||
| single = float32 | |||
| float64 = typing.Float(64) | |||
| double = float64 | |||
| complex64 = typing.Complex(64) | |||
| complex128 = typing.Complex(128) | |||
| number = typing.Number() | |||
| int_ = typing.Int() | |||
| @@ -124,14 +127,16 @@ number_type = (int8, | |||
| uint64, | |||
| float16, | |||
| float32, | |||
| float64,) | |||
| float64, | |||
| complex64, | |||
| complex128,) | |||
| int_type = (int8, int16, int32, int64,) | |||
| uint_type = (uint8, uint16, uint32, uint64,) | |||
| float_type = (float16, float32, float64,) | |||
| implicit_conversion_seq = {t: idx for idx, t in enumerate(( | |||
| bool_, int8, uint8, int16, int32, int64, float16, float32, float64))} | |||
| bool_, int8, uint8, int16, int32, int64, float16, float32, float64, complex64, complex128))} | |||
| _simple_types = { | |||
| list: list_, | |||
| @@ -140,6 +145,7 @@ _simple_types = { | |||
| bool: bool_, | |||
| int: int64, | |||
| float: float64, | |||
| complex: complex128, | |||
| str: string, | |||
| np.bool_: bool_, | |||
| np.str: string, | |||
| @@ -228,6 +234,8 @@ def dtype_to_nptype(type_): | |||
| float16: np.float16, | |||
| float32: np.float32, | |||
| float64: np.float64, | |||
| complex64: np.complex64, | |||
| complex128: np.complex128, | |||
| }[type_] | |||
| @@ -260,6 +268,8 @@ def dtype_to_pytype(type_): | |||
| list_: list, | |||
| tuple_: tuple, | |||
| string: str, | |||
| complex64: complex, | |||
| complex128: complex, | |||
| type_none: type(None) | |||
| }[type_] | |||
| @@ -26,7 +26,7 @@ from .._checkparam import Validator as validator | |||
| __all__ = ['Tensor', 'RowTensor', 'SparseTensor'] | |||
| np_types = (np.int8, np.int16, np.int32, np.int64, | |||
| np.uint8, np.uint16, np.uint32, np.uint64, np.float16, | |||
| np.float32, np.float64, np.bool_) | |||
| np.float32, np.float64, np.bool_, np.complex64, np.complex128) | |||
| class Tensor(Tensor_): | |||
| @@ -91,7 +91,7 @@ class Tensor(Tensor_): | |||
| validator.check_value_type('input_data', input_data, (Tensor_, np.ndarray, list, tuple, float, int, bool), | |||
| 'Tensor') | |||
| valid_dtypes = (np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64, | |||
| np.float16, np.float32, np.float64, np.bool_, np.str_) | |||
| np.float16, np.float32, np.float64, np.bool_, np.str_, np.complex64, np.complex128) | |||
| if isinstance(input_data, np.ndarray) and input_data.dtype not in valid_dtypes and \ | |||
| input_data.dtype.kind != 'U': # Support dtype np.str_ | |||
| raise TypeError(f"For Tensor, the input_data is a numpy array, " | |||
| @@ -27,11 +27,12 @@ | |||
| namespace mindspore { | |||
| namespace abstract { | |||
| const std::map<TypeId, size_t> type_map = {{kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1}, | |||
| {kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8}, | |||
| {kNumberTypeUInt, 4}, {kNumberTypeUInt8, 1}, {kNumberTypeUInt16, 2}, | |||
| {kNumberTypeUInt32, 4}, {kNumberTypeUInt64, 8}, {kNumberTypeFloat, 4}, | |||
| {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}}; | |||
| const std::map<TypeId, size_t> type_map = { | |||
| {kNumberTypeBool, 1}, {kNumberTypeInt, 4}, {kNumberTypeInt8, 1}, {kNumberTypeInt16, 2}, | |||
| {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8}, {kNumberTypeUInt, 4}, {kNumberTypeUInt8, 1}, | |||
| {kNumberTypeUInt16, 2}, {kNumberTypeUInt32, 4}, {kNumberTypeUInt64, 8}, {kNumberTypeFloat, 4}, | |||
| {kNumberTypeFloat16, 2}, {kNumberTypeFloat32, 4}, {kNumberTypeFloat64, 8}, {kNumberTypeComplex64, 8}, | |||
| {kNumberTypeComplex128, 16}}; | |||
| ValuePtr ValueJoin(const ValuePtr &value1, const ValuePtr &value2) { | |||
| MS_EXCEPTION_IF_NULL(value1); | |||
| @@ -0,0 +1,75 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_BASE_COMPLEX_STORAGE_H_ | |||
| #define MINDSPORE_CORE_BASE_COMPLEX_STORAGE_H_ | |||
| #include "base/float16.h" | |||
| namespace mindspore { | |||
| template <typename T> | |||
| struct alignas(sizeof(T) * 2) ComplexStorage { | |||
| T real_; | |||
| T imag_; | |||
| ComplexStorage() = default; | |||
| ~ComplexStorage() = default; | |||
| ComplexStorage(const ComplexStorage<T> &other) noexcept = default; | |||
| ComplexStorage(ComplexStorage<T> &&other) noexcept = default; | |||
| ComplexStorage &operator=(const ComplexStorage<T> &other) noexcept = default; | |||
| ComplexStorage &operator=(ComplexStorage<T> &&other) noexcept = default; | |||
| inline constexpr ComplexStorage(const T &real, const T &imag = T()) : real_(real), imag_(imag) {} | |||
| inline explicit constexpr ComplexStorage(const float16 &real) : real_(static_cast<T>(real)), imag_(T()) {} | |||
| template <typename U = T> | |||
| explicit ComplexStorage(const std::enable_if_t<std::is_same<U, float>::value, ComplexStorage<double>> &other) | |||
| : real_(other.real_), imag_(other.imag_) {} | |||
| template <typename U = T> | |||
| explicit ComplexStorage(const std::enable_if_t<std::is_same<U, double>::value, ComplexStorage<float>> &other) | |||
| : real_(other.real_), imag_(other.imag_) {} | |||
| inline explicit operator bool() const { return static_cast<bool>(real_) || static_cast<bool>(imag_); } | |||
| inline explicit operator signed char() const { return static_cast<signed char>(real_); } | |||
| inline explicit operator unsigned char() const { return static_cast<unsigned char>(real_); } | |||
| inline explicit operator double() const { return static_cast<double>(real_); } | |||
| inline explicit operator float() const { return static_cast<float>(real_); } | |||
| inline explicit operator int16_t() const { return static_cast<int16_t>(real_); } | |||
| inline explicit operator uint16_t() const { return static_cast<uint16_t>(real_); } | |||
| inline explicit operator int32_t() const { return static_cast<int32_t>(real_); } | |||
| inline explicit operator uint32_t() const { return static_cast<uint32_t>(real_); } | |||
| inline explicit operator int64_t() const { return static_cast<int64_t>(real_); } | |||
| inline explicit operator uint64_t() const { return static_cast<uint64_t>(real_); } | |||
| inline explicit operator float16() const { return static_cast<float16>(real_); } | |||
| }; | |||
| template <typename T> | |||
| inline bool operator==(const ComplexStorage<T> &lhs, const ComplexStorage<T> &rhs) { | |||
| return lhs.real_ == rhs.real_ && lhs.imag_ == rhs.imag_; | |||
| } | |||
| template <typename T> | |||
| inline std::ostream &operator<<(std::ostream &os, const ComplexStorage<T> &v) { | |||
| return (os << std::noshowpos << v.real_ << std::showpos << v.imag_ << 'j'); | |||
| } | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_BASE_COMPLEX_STORAGE_H_ | |||
| @@ -46,4 +46,10 @@ Float::Float(const int nbits) : Number(FloatBitsToTypeId(nbits), nbits, false) { | |||
| MS_LOG(EXCEPTION) << "Wrong number of bits."; | |||
| } | |||
| } | |||
| Complex::Complex(const int nbits) : Number(ComplexBitsToTypeId(nbits), nbits, false) { | |||
| if (nbits != 64 && nbits != 128) { | |||
| MS_LOG(EXCEPTION) << "Wrong number of bits."; | |||
| } | |||
| } | |||
| } // namespace mindspore | |||
| @@ -150,20 +150,19 @@ class MS_CORE_API Float : public Number { | |||
| } | |||
| }; | |||
| // Complex64 | |||
| class MS_CORE_API Complex64 : public Number { | |||
| // Complex | |||
| class MS_CORE_API Complex : public Number { | |||
| public: | |||
| Complex64() : Number(kNumberTypeComplex64, 64, false) {} | |||
| ~Complex64() override {} | |||
| MS_DECLARE_PARENT(Complex64, Number) | |||
| Complex() : Number(kNumberTypeComplex64, 64, false) {} | |||
| explicit Complex(const int nbits); | |||
| ~Complex() override {} | |||
| MS_DECLARE_PARENT(Complex, Number) | |||
| TypeId generic_type_id() const override { return kNumberTypeComplex64; } | |||
| TypePtr DeepCopy() const override { return std::make_shared<Complex64>(); } | |||
| TypePtr DeepCopy() const override { return std::make_shared<Complex>(nbits()); } | |||
| std::string ToString() const override { return GetTypeName("Complex"); } | |||
| std::string ToReprString() const override { return nbits() == 0 ? "complex64_" : GetTypeName("complex64"); } | |||
| std::string DumpText() const override { | |||
| return nbits() == 0 ? std::string("Complex64") : std::string("C") + std::to_string(nbits()); | |||
| } | |||
| std::string ToReprString() const override { return GetTypeName("complex"); } | |||
| std::string DumpText() const override { return std::string("C") + std::to_string(nbits()); } | |||
| }; | |||
| inline const TypePtr kBool = std::make_shared<Bool>(); | |||
| @@ -182,7 +181,8 @@ inline const TypePtr kInt = std::make_shared<Int>(); | |||
| inline const TypePtr kUInt = std::make_shared<UInt>(); | |||
| inline const TypePtr kFloat = std::make_shared<Float>(); | |||
| inline const TypePtr kNumber = std::make_shared<Number>(); | |||
| inline const TypePtr kComplex64 = std::make_shared<Complex64>(); | |||
| inline const TypePtr kComplex64 = std::make_shared<Complex>(64); | |||
| inline const TypePtr kComplex128 = std::make_shared<Complex>(128); | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_IR_DTYPE_NUMBER_H_ | |||
| @@ -87,6 +87,7 @@ enum class BitsNum : int { | |||
| eBits16 = 16, | |||
| eBits32 = 32, | |||
| eBits64 = 64, | |||
| eBits128 = 128, | |||
| }; | |||
| TypeId IntBitsToTypeId(const int nbits) { | |||
| switch (nbits) { | |||
| @@ -131,6 +132,17 @@ TypeId FloatBitsToTypeId(const int nbits) { | |||
| } | |||
| } | |||
| TypeId ComplexBitsToTypeId(const int nbits) { | |||
| switch (nbits) { | |||
| case static_cast<int>(BitsNum::eBits64): | |||
| return kNumberTypeComplex64; | |||
| case static_cast<int>(BitsNum::eBits128): | |||
| return kNumberTypeComplex128; | |||
| default: | |||
| MS_LOG(EXCEPTION) << "Wrong number of bits:" << nbits; | |||
| } | |||
| } | |||
| const std::string &TypeIdLabel(const TypeId &v) { | |||
| static const std::string unknown("[Unknown Type Id]"); | |||
| auto iter = g_type_2_lable.find(v); | |||
| @@ -41,6 +41,7 @@ namespace mindspore { | |||
| TypeId IntBitsToTypeId(const int nbits); | |||
| TypeId UIntBitsToTypeId(const int nbits); | |||
| TypeId FloatBitsToTypeId(const int nbits); | |||
| TypeId ComplexBitsToTypeId(const int nbits); | |||
| const std::string &TypeIdLabel(const TypeId &v); | |||
| TypeId NormalizeTypeId(const TypeId type_id); | |||
| bool IsSameObjectType(const Type &lhs, const Type &rhs); | |||
| @@ -79,6 +79,7 @@ enum TypeId : int { | |||
| kNumberTypeFloat32, | |||
| kNumberTypeFloat64, | |||
| kNumberTypeComplex64, | |||
| kNumberTypeComplex128, | |||
| kNumberTypeEnd, | |||
| // | |||
| // Monad Types | |||
| @@ -61,41 +61,20 @@ bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) c | |||
| } | |||
| TypePtr TypeIdToType(TypeId id) { | |||
| static std::unordered_map<TypeId, TypePtr> type_id_to_type = {{kNumberTypeFloat16, kFloat16}, | |||
| {kNumberTypeFloat, kFloat32}, | |||
| {kNumberTypeFloat32, kFloat32}, | |||
| {kNumberTypeFloat64, kFloat64}, | |||
| {kNumberTypeComplex64, kComplex64}, | |||
| {kNumberTypeInt8, kInt8}, | |||
| {kNumberTypeInt16, kInt16}, | |||
| {kNumberTypeInt32, kInt32}, | |||
| {kNumberTypeInt, kInt32}, | |||
| {kNumberTypeInt64, kInt64}, | |||
| {kNumberTypeUInt8, kUInt8}, | |||
| {kNumberTypeUInt16, kUInt16}, | |||
| {kNumberTypeUInt32, kUInt32}, | |||
| {kNumberTypeUInt64, kUInt64}, | |||
| {kNumberTypeBool, kBool}, | |||
| {kMetaTypeExternal, kTypeExternal}, | |||
| {kMetaTypeAnything, kAnyType}, | |||
| {kMetaTypeNone, kTypeNone}, | |||
| {kMetaTypeNull, kTypeNull}, | |||
| {kMetaTypeEllipsis, kTypeEllipsis}, | |||
| {kObjectTypeEnvType, kTypeEnv}, | |||
| {kObjectTypeRefKey, kRefKeyType}, | |||
| {kObjectTypeRef, kRefType}, | |||
| {kMetaTypeTypeType, kTypeType}, | |||
| {kObjectTypeString, kString}, | |||
| {kObjectTypeList, kList}, | |||
| {kObjectTypeTuple, kTuple}, | |||
| {kObjectTypeDictionary, kDict}, | |||
| {kObjectTypeSlice, kSlice}, | |||
| {kObjectTypeKeyword, kKeyword}, | |||
| {kObjectTypeTensorType, kTensorType}, | |||
| {kObjectTypeUMonad, kUMonadType}, | |||
| {kObjectTypeIOMonad, kIOMonadType}, | |||
| {kTypeUnknown, kTypeNone}, | |||
| {kMetaTypeProblem, kTypeNone}}; | |||
| static std::unordered_map<TypeId, TypePtr> type_id_to_type = { | |||
| {kNumberTypeFloat16, kFloat16}, {kNumberTypeFloat, kFloat32}, {kNumberTypeFloat32, kFloat32}, | |||
| {kNumberTypeFloat64, kFloat64}, {kNumberTypeComplex64, kComplex64}, {kNumberTypeInt8, kInt8}, | |||
| {kNumberTypeInt16, kInt16}, {kNumberTypeInt32, kInt32}, {kNumberTypeInt, kInt32}, | |||
| {kNumberTypeInt64, kInt64}, {kNumberTypeUInt8, kUInt8}, {kNumberTypeUInt16, kUInt16}, | |||
| {kNumberTypeUInt32, kUInt32}, {kNumberTypeUInt64, kUInt64}, {kNumberTypeBool, kBool}, | |||
| {kNumberTypeComplex64, kComplex64}, {kNumberTypeComplex128, kComplex128}, {kMetaTypeExternal, kTypeExternal}, | |||
| {kMetaTypeAnything, kAnyType}, {kMetaTypeNone, kTypeNone}, {kMetaTypeNull, kTypeNull}, | |||
| {kMetaTypeEllipsis, kTypeEllipsis}, {kObjectTypeEnvType, kTypeEnv}, {kObjectTypeRefKey, kRefKeyType}, | |||
| {kObjectTypeRef, kRefType}, {kMetaTypeTypeType, kTypeType}, {kObjectTypeString, kString}, | |||
| {kObjectTypeList, kList}, {kObjectTypeTuple, kTuple}, {kObjectTypeDictionary, kDict}, | |||
| {kObjectTypeSlice, kSlice}, {kObjectTypeKeyword, kKeyword}, {kObjectTypeTensorType, kTensorType}, | |||
| {kObjectTypeUMonad, kUMonadType}, {kObjectTypeIOMonad, kIOMonadType}, {kTypeUnknown, kTypeNone}, | |||
| {kMetaTypeProblem, kTypeNone}}; | |||
| const auto &it = type_id_to_type.find(id); | |||
| if (it == type_id_to_type.end()) { | |||
| MS_LOG(EXCEPTION) << "Not support the type: " << id; | |||
| @@ -31,6 +31,7 @@ | |||
| #include "abstract/utils.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "base/complex_storage.h" | |||
| namespace mindspore { | |||
| namespace tensor { | |||
| @@ -73,7 +74,10 @@ std::unique_ptr<T[]> NewData(const U *input, size_t size) { | |||
| return nullptr; | |||
| } | |||
| auto data = std::make_unique<T[]>(size); | |||
| if constexpr (!std::is_same<T, U>::value && (std::is_same<T, float16>::value || std::is_same<U, float16>::value)) { | |||
| if constexpr (!std::is_same<T, U>::value && | |||
| (std::is_same<T, float16>::value || std::is_same<U, float16>::value || | |||
| std::is_same<T, ComplexStorage<float>>::value || std::is_same<U, ComplexStorage<float>>::value || | |||
| std::is_same<T, ComplexStorage<double>>::value || std::is_same<U, ComplexStorage<double>>::value)) { | |||
| // Because float16 do not support implicit cast from/to other types, | |||
| // We can not use std::copy() on array of float16, use a loop here. | |||
| for (size_t i = 0; i < size; ++i) { | |||
| @@ -146,7 +150,11 @@ std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, TypeId | |||
| return NewData<T>(buf, size); | |||
| } | |||
| case kNumberTypeComplex64: { | |||
| auto buf = static_cast<double *>(data); | |||
| auto buf = static_cast<ComplexStorage<float> *>(data); | |||
| return NewData<T>(buf, size); | |||
| } | |||
| case kNumberTypeComplex128: { | |||
| auto buf = static_cast<ComplexStorage<double> *>(data); | |||
| return NewData<T>(buf, size); | |||
| } | |||
| case kObjectTypeString: { | |||
| @@ -233,7 +241,8 @@ class TensorDataImpl : public TensorData { | |||
| std::is_same<T, bool>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value || | |||
| std::is_same<T, int16_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, int64_t>::value || | |||
| std::is_same<T, uint16_t>::value || std::is_same<T, uint32_t>::value || std::is_same<T, uint64_t>::value || | |||
| std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value; | |||
| std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value || | |||
| std::is_same<T, ComplexStorage<float>>::value || std::is_same<T, ComplexStorage<double>>::value; | |||
| static_assert(valid, "Type is invalid"); | |||
| if (data_size_ == 0) { | |||
| return ""; | |||
| @@ -302,10 +311,14 @@ class TensorDataImpl : public TensorData { | |||
| constexpr auto isBool = std::is_same<T, bool>::value; | |||
| constexpr auto isFloat = | |||
| std::is_same<T, float16>::value || std::is_same<T, float>::value || std::is_same<T, double>::value; | |||
| constexpr auto isComplex = | |||
| std::is_same<T, ComplexStorage<float>>::value || std::is_same<T, ComplexStorage<double>>::value; | |||
| constexpr int linefeedThreshold = isFloat ? kThreshold1DFloat : (isBool ? kThreshold1DBool : kThreshold1DInt); | |||
| for (ssize_t i = start; i < end && (cursor + i) < static_cast<ssize_t>(data_size_); i++) { | |||
| const auto value = data_[cursor + i]; | |||
| if constexpr (isFloat) { | |||
| if constexpr (isComplex) { | |||
| ss << value; | |||
| } else if constexpr (isFloat) { | |||
| OutputFloatDataString(ss, isScalar, value); | |||
| } else if (isBool) { | |||
| OutputBoolDataString(ss, isScalar, value); | |||
| @@ -458,7 +471,9 @@ TensorDataPtr MakeTensorData(TypeId data_type, const ShapeVector &shape, const A | |||
| case kNumberTypeFloat64: | |||
| return std::make_shared<TensorDataImpl<double>>(shape, args...); | |||
| case kNumberTypeComplex64: | |||
| return std::make_shared<TensorDataImpl<double>>(shape, args...); | |||
| return std::make_shared<TensorDataImpl<ComplexStorage<float>>>(shape, args...); | |||
| case kNumberTypeComplex128: | |||
| return std::make_shared<TensorDataImpl<ComplexStorage<double>>>(shape, args...); | |||
| case kObjectTypeString: | |||
| return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...); | |||
| case kObjectTypeTensorType: | |||
| @@ -908,3 +908,6 @@ class DataType: | |||
| F64_HWCN = ("float64", "HWCN") | |||
| F64_NDHWC = ("float64", "NDHWC") | |||
| F64_ChannelLast = ("float64", "ChannelLast") | |||
| C64_Default = ("complex64", "DefaultFormat") | |||
| C128_Default = ("complex128", "DefaultFormat") | |||
| @@ -0,0 +1,158 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <memory> | |||
| #include "common/common_test.h" | |||
| #include "utils/complex.h" | |||
| namespace mindspore { | |||
| class TestComplex : public UT::Common { | |||
| public: | |||
| TestComplex() {} | |||
| }; | |||
| TEST_F(TestComplex, test_size) { | |||
| ASSERT_EQ(sizeof(Complex<float>), 2 * sizeof(float)); | |||
| ASSERT_EQ(sizeof(Complex<double>), 2 * sizeof(double)); | |||
| ASSERT_EQ(alignof(Complex<float>), 2 * sizeof(float)); | |||
| ASSERT_EQ(alignof(Complex<double>), 2 * sizeof(double)); | |||
| } | |||
| template <typename T> | |||
| void test_construct() { | |||
| constexpr T real = T(1.11f); | |||
| constexpr T imag = T(2.22f); | |||
| ASSERT_EQ(Complex<T>().real(), T()); | |||
| ASSERT_EQ(Complex<T>().imag(), T()); | |||
| ASSERT_EQ(Complex<T>(real, imag).real(), real); | |||
| ASSERT_EQ(Complex<T>(real, imag).imag(), imag); | |||
| ASSERT_EQ(Complex<T>(real).real(), real); | |||
| ASSERT_EQ(Complex<T>(real).imag(), T()); | |||
| } | |||
| template <typename T1, typename T2> | |||
| void test_conver_construct() { | |||
| ASSERT_EQ(Complex<T1>(Complex<T2>(T2(1.11f), T2(2.22f))).real(), T1(1.11f)); | |||
| ASSERT_EQ(Complex<T1>(Complex<T2>(T2(1.11f), T2(2.22f))).imag(), T1(2.22f)); | |||
| } | |||
| template <typename T> | |||
| void test_conver_std_construct() { | |||
| ASSERT_EQ(Complex<T>(std::complex<T>(T(1.11f), T(2.22f))).real(), T(1.11f)); | |||
| ASSERT_EQ(Complex<T>(std::complex<T>(T(1.11f), T(2.22f))).imag(), T(2.22f)); | |||
| } | |||
| TEST_F(TestComplex, test_construct) { | |||
| test_construct<float>(); | |||
| test_construct<double>(); | |||
| test_conver_construct<float, float>(); | |||
| test_conver_construct<double, double>(); | |||
| test_conver_construct<float, double>(); | |||
| test_conver_construct<double, float>(); | |||
| test_conver_std_construct<float>(); | |||
| test_conver_std_construct<double>(); | |||
| } | |||
| template <typename T> | |||
| void test_convert_operator(T &&a) { | |||
| ASSERT_EQ(static_cast<T>(Complex<float>(a)), a); | |||
| } | |||
| TEST_F(TestComplex, test_convert_operator) { | |||
| test_convert_operator<bool>(true); | |||
| test_convert_operator<signed char>(1); | |||
| test_convert_operator<unsigned char>(1); | |||
| ASSERT_NEAR(static_cast<double>(Complex<float>(1.11)), 1.11, 0.001); | |||
| test_convert_operator<float>(1.11f); | |||
| test_convert_operator<int16_t>(1); | |||
| test_convert_operator<uint16_t>(1); | |||
| test_convert_operator<int32_t>(1); | |||
| test_convert_operator<uint32_t>(1); | |||
| test_convert_operator<int64_t>(1); | |||
| test_convert_operator<uint64_t>(1); | |||
| float16 a(1.11f); | |||
| ASSERT_EQ(static_cast<float16>(Complex<float>(a)), a); | |||
| } | |||
| TEST_F(TestComplex, test_assign_operator) { | |||
| Complex<float> a = 1.11f; | |||
| std::cout << a << std::endl; | |||
| ASSERT_EQ(a.real(), 1.11f); | |||
| ASSERT_EQ(a.imag(), float()); | |||
| a = Complex<double>(2.22f, 1.11f); | |||
| ASSERT_EQ(a.real(), 2.22f); | |||
| ASSERT_EQ(a.imag(), 1.11f); | |||
| } | |||
| template <typename T1, typename T2, typename T3> | |||
| void test_arithmetic_add(T1 lhs, T2 rhs, T3 r) { | |||
| ASSERT_EQ(lhs + rhs, r); | |||
| if constexpr (!(std::is_same<T1, float>::value || std::is_same<T1, double>::value)) { | |||
| ASSERT_EQ(lhs += rhs, r); | |||
| } | |||
| } | |||
| template <typename T1, typename T2, typename T3> | |||
| void test_arithmetic_sub(T1 lhs, T2 rhs, T3 r) { | |||
| ASSERT_EQ(lhs - rhs, r); | |||
| if constexpr (!(std::is_same<T1, float>::value || std::is_same<T1, double>::value)) { | |||
| ASSERT_EQ(lhs -= rhs, r); | |||
| } | |||
| } | |||
| template <typename T1, typename T2, typename T3> | |||
| void test_arithmetic_mul(T1 lhs, T2 rhs, T3 r) { | |||
| ASSERT_EQ(lhs * rhs, r); | |||
| if constexpr (!(std::is_same<T1, float>::value || std::is_same<T1, double>::value)) { | |||
| ASSERT_EQ(lhs *= rhs, r); | |||
| } | |||
| } | |||
| template <typename T1, typename T2, typename T3> | |||
| void test_arithmetic_div(T1 lhs, T2 rhs, T3 r) { | |||
| ASSERT_EQ(lhs / rhs, r); | |||
| if constexpr (!(std::is_same<T1, float>::value || std::is_same<T1, double>::value)) { | |||
| ASSERT_EQ(lhs /= rhs, r); | |||
| } | |||
| } | |||
| TEST_F(TestComplex, test_arithmetic) { | |||
| test_arithmetic_add<Complex<float>, Complex<float>, Complex<float>>( | |||
| Complex<float>(1.11, 2.22), Complex<float>(1.11, 2.22), Complex<float>(2.22, 4.44)); | |||
| test_arithmetic_add<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11, | |||
| Complex<float>(2.22, 2.22)); | |||
| test_arithmetic_add<float, Complex<float>, Complex<float>>(1.11, Complex<float>(1.11, 2.22), | |||
| Complex<float>(2.22, 2.22)); | |||
| test_arithmetic_sub<Complex<float>, Complex<float>, Complex<float>>(Complex<float>(1.11, 2.22), | |||
| Complex<float>(1.11, 2.22), Complex<float>(0, 0)); | |||
| test_arithmetic_sub<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11, Complex<float>(0, 2.22)); | |||
| test_arithmetic_sub<float, Complex<float>, Complex<float>>(1.11, Complex<float>(1.11, 2.22), | |||
| Complex<float>(0, -2.22)); | |||
| test_arithmetic_mul<Complex<float>, Complex<float>, Complex<float>>( | |||
| Complex<float>(1.11, 2.22), Complex<float>(1.11, 2.22), Complex<float>(-3.6963, 4.9284)); | |||
| test_arithmetic_mul<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11, | |||
| Complex<float>(1.2321, 2.22)); | |||
| test_arithmetic_mul<float, Complex<float>, Complex<float>>(1.11, Complex<float>(1.11, 2.22), | |||
| Complex<float>(1.2321, 2.22)); | |||
| test_arithmetic_div<Complex<float>, Complex<float>, Complex<float>>(Complex<float>(1.11, 2.22), | |||
| Complex<float>(1.11, 2.22), Complex<float>(1, 0)); | |||
| test_arithmetic_div<Complex<float>, float, Complex<float>>(Complex<float>(1.11, 2.22), 1.11, Complex<float>(1, 2.22)); | |||
| test_arithmetic_div<float, Complex<float>, Complex<float>>(1.11, Complex<float>(1.11, 2.22), | |||
| Complex<float>(0.2, -0.4)); | |||
| } | |||
| } // namespace mindspore | |||
| @@ -35,6 +35,8 @@ def test_dtype_to_nptype(): | |||
| assert ms.dtype_to_nptype(ms.float16) == np.float16 | |||
| assert ms.dtype_to_nptype(ms.float32) == np.float32 | |||
| assert ms.dtype_to_nptype(ms.float64) == np.float64 | |||
| assert ms.dtype_to_nptype(ms.complex64) == np.complex64 | |||
| assert ms.dtype_to_nptype(ms.complex128) == np.complex128 | |||
| def test_dtype_to_pytype(): | |||
| @@ -51,6 +53,8 @@ def test_dtype_to_pytype(): | |||
| assert ms.dtype_to_pytype(ms.float16) == float | |||
| assert ms.dtype_to_pytype(ms.float32) == float | |||
| assert ms.dtype_to_pytype(ms.float64) == float | |||
| assert ms.dtype_to_pytype(ms.complex64) == complex | |||
| assert ms.dtype_to_pytype(ms.complex128) == complex | |||
| assert ms.dtype_to_pytype(ms.list_) == list | |||
| assert ms.dtype_to_pytype(ms.tuple_) == tuple | |||
| assert ms.dtype_to_pytype(ms.string) == str | |||
| @@ -94,6 +98,12 @@ def test_dtype(): | |||
| me_type = dtype.get_py_obj_dtype(x) | |||
| assert me_type == ms.bool_ | |||
| x = 0.1+3j | |||
| me_type = dtype.get_py_obj_dtype(type(x)) | |||
| assert me_type == ms.complex128 | |||
| me_type = dtype.get_py_obj_dtype(x) | |||
| assert me_type == ms.complex128 | |||
| # support str | |||
| # x = "string type" | |||
| @@ -74,6 +74,45 @@ def test_tensor_type_float16(): | |||
| assert t_float16.shape == (2, 3) | |||
| assert t_float16.dtype == ms.float16 | |||
| def test_tensor_type_complex64(): | |||
| np_input = np.array( | |||
| [[1+0.1j, 2j, 3+0.3j], [4-0.4j, 5, 6]], dtype=np.complex64) | |||
| t_complex64 = ms.Tensor(np_input) | |||
| assert isinstance(t_complex64, ms.Tensor) | |||
| assert t_complex64.shape == (2, 3) | |||
| assert t_complex64.dtype == ms.complex64 | |||
| assert np.all(t_complex64.asnumpy() == np_input) | |||
| def test_tensor_type_complex64_user_define(): | |||
| np_input = np.zeros([1, 2, 3]) | |||
| t_complex64 = ms.Tensor(np_input, ms.complex64) | |||
| assert isinstance(t_complex64, ms.Tensor) | |||
| assert t_complex64.shape == (1, 2, 3) | |||
| assert t_complex64.dtype == ms.complex64 | |||
| assert np.all(t_complex64.asnumpy() == np_input) | |||
| def test_tensor_type_complex128(): | |||
| np_input = np.array( | |||
| [[1+0.1j, 2j, 3+0.3j], [4-0.4j, 5, 6]], dtype=np.complex128) | |||
| t_complex128 = ms.Tensor(np_input) | |||
| assert isinstance(t_complex128, ms.Tensor) | |||
| assert t_complex128.shape == (2, 3) | |||
| assert t_complex128.dtype == ms.complex128 | |||
| assert np.all(t_complex128.asnumpy() == np_input) | |||
| np_input = (1, 2.22222222j, 3) | |||
| t_complex128 = ms.Tensor(np_input) | |||
| assert np.all(t_complex128.asnumpy() == np_input) | |||
| def test_tensor_type_complex128_user_define(): | |||
| np_input = np.zeros([1, 2, 3]) | |||
| t_complex128 = ms.Tensor(np_input, ms.complex128) | |||
| assert isinstance(t_complex128, ms.Tensor) | |||
| assert t_complex128.shape == (1, 2, 3) | |||
| assert t_complex128.dtype == ms.complex128 | |||
| assert np.all(t_complex128.asnumpy() == np_input) | |||
| def test_tensor_type_float32(): | |||
| t_float32 = ms.Tensor(np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)) | |||
| @@ -332,13 +371,6 @@ def test_tensor_input_ndarray_bool(): | |||
| inp = np.array([False, 2, 4]) | |||
| ms.Tensor(inp) | |||
| def test_tensor_input_ndarray_complex(): | |||
| with pytest.raises(TypeError): | |||
| inp = np.array([20j, 2, 4]) | |||
| ms.Tensor(inp) | |||
| def test_tensor_input_ndarray_none(): | |||
| with pytest.raises(TypeError): | |||
| inp = np.array([None, 2, 4]) | |||
| @@ -445,6 +477,19 @@ def test_tensor_dtype_fp64_to_uint8(): | |||
| assert t.shape == (2, 3) | |||
| assert t.dtype == ms.uint8 | |||
| def test_tensor_dtype_complex64_to_float32(): | |||
| array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.complex64) | |||
| t = ms.Tensor(array, ms.float32) | |||
| assert isinstance(t, ms.Tensor) | |||
| assert t.shape == (2, 3) | |||
| assert t.dtype == ms.float32 | |||
| def test_tensor_dtype_float32_to_complex64(): | |||
| array = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) | |||
| t = ms.Tensor(array, ms.complex64) | |||
| assert isinstance(t, ms.Tensor) | |||
| assert t.shape == (2, 3) | |||
| assert t.dtype == ms.complex64 | |||
| def test_tensor_operation(): | |||
| x = Tensor(np.ones((3, 3)) * 4) | |||
| @@ -200,6 +200,12 @@ def test_parameter_lazy_init(): | |||
| assert isinstance(para.data, Tensor) | |||
| assert np.array_equal(para.data.asnumpy(), np.ones((1, 2, 3))) | |||
| para = Parameter(initializer('ones', [1, 2, 3], mstype.complex64), 'test1') | |||
| assert isinstance(para.data, Tensor) | |||
| para = para.init_data() | |||
| assert isinstance(para.data, Tensor) | |||
| assert np.array_equal(para.data.asnumpy(), np.ones((1, 2, 3))) | |||
| # Call init_data() after set_data is set. | |||
| para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2') | |||
| assert isinstance(para.data, Tensor) | |||