/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/operator/cc_implementations.h"
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <limits>
#include <string>
#include <type_traits>
#include "utils/log_adapter.h"
#include "utils/convert_utils.h"
#include "utils/ms_utils.h"

namespace mindspore {
// namespace to support primitive operators definition
namespace prim {
enum class DataType { kInt, kInt64, kFloat, kDouble, kUnknown };

// Returns true if at least one element of `list` holds a value of type T.
template <typename T>
bool HasType(const AnyPtrList &list) {
  bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is<T>(); });
  return ret;
}

// Returns the DataType of the first type present in `list`, checked in the priority order
// double, float, int64_t, int; kUnknown when none of them is present.
DataType InferType(const AnyPtrList &list) {
  if (HasType<double>(list)) {
    return DataType::kDouble;
  } else if (HasType<float>(list)) {
    return DataType::kFloat;
  } else if (HasType<int64_t>(list)) {
    return DataType::kInt64;
  } else if (HasType<int>(list)) {
    return DataType::kInt;
  }
  return DataType::kUnknown;
}

// x + y. For signed integral T the sum is overflow-checked (signed overflow is undefined behavior in C++)
// and an overflow is reported via MS_EXCEPTION(ValueError); other types use the plain operator.
template <typename T>
T InnerScalarAdd(T x, T y) {
  if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
    T res = 0;
    if (__builtin_add_overflow(x, y, &res)) {
      MS_EXCEPTION(ValueError) << "Overflow of the sum of two signed number x: " << std::to_string(x)
                               << ", y: " << std::to_string(y) << ".";
    }
    return res;
  } else {
    return x + y;
  }
}

// x - y, overflow-checked for signed integral T in the same way as InnerScalarAdd.
template <typename T>
T InnerScalarSub(T x, T y) {
  if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
    T res = 0;
    if (__builtin_sub_overflow(x, y, &res)) {
      MS_EXCEPTION(ValueError) << "Overflow of the sub of two signed number x: " << std::to_string(x)
                               << ", y: " << std::to_string(y) << ".";
    }
    return res;
  } else {
    return x - y;
  }
}

// x * y, overflow-checked for signed integral T in the same way as InnerScalarAdd.
template <typename T>
T InnerScalarMul(T x, T y) {
  if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
    T res = 0;
    if (__builtin_mul_overflow(x, y, &res)) {
      MS_EXCEPTION(ValueError) << "Overflow of the mul of two signed number x: " << std::to_string(x)
                               << ", y: " << std::to_string(y) << ".";
    }
    return res;
  } else {
    return x * y;
  }
}

// True division x / y evaluated in double precision; callers narrow the quotient to their own result type.
// Computing in double (rather than float) keeps FP64 and Int64 quotients from being rounded to single precision.
// A zero divisor is reported via MS_EXCEPTION(ValueError). So is min() / -1 for signed integral T, whose
// quotient does not fit back into T.
template <typename T>
double InnerScalarDiv(T x, T y) {
  if (y == 0) {
    MS_EXCEPTION(ValueError) << "The divisor could not be zero.";
  }
  if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
    // Compare in T itself: narrowing y to a smaller type first could turn a large divisor into -1.
    if (x == std::numeric_limits<T>::min() && y == static_cast<T>(-1)) {
      MS_EXCEPTION(ValueError) << "Overflow of the div of two signed number x: " << std::to_string(x)
                               << ", y: " << std::to_string(y) << ".";
    }
  }
  return static_cast<double>(x) / static_cast<double>(y);
}

// floor(x / y) converted back to T, using the double quotient from InnerScalarDiv (which performs the
// zero-divisor and overflow checks).
// NOTE(review): int64 magnitudes above 2^53 are not exact in double, so very large operands may round.
template <typename T>
T InnerScalarFloordiv(T x, T y) {
  auto ret = std::floor(InnerScalarDiv(x, y));
  return static_cast<T>(ret);
}

// Modulo whose result takes the sign of the divisor (floor semantics). The floating-point branch and the
// integral branch agree with each other and with the floor-based InnerScalarFloordiv. A zero divisor, and
// min() % -1 for signed integral T (undefined behavior in C++), are reported via MS_EXCEPTION(ValueError).
template <typename T>
T InnerScalarMod(T x, T y) {
  if (y == 0) {
    MS_EXCEPTION(ValueError) << "Could not mod to zero.";
  }
  if constexpr (!std::is_integral<T>::value) {
    return x - y * std::floor(x / y);
  } else if constexpr (std::is_signed<T>::value) {
    if (x == std::numeric_limits<T>::min() && y == static_cast<T>(-1)) {
      MS_EXCEPTION(ValueError) << "Overflow of the mod of two signed number x: " << std::to_string(x)
                               << ", y: " << std::to_string(y) << ".";
    }
    T res = x % y;
    // Built-in % truncates toward zero, so the remainder carries the sign of x; move it into y's sign
    // (e.g. -7 mod 3 becomes 2, not -1). |res| < |y| so this addition cannot overflow.
    if (res != 0 && ((res < 0) != (y < 0))) {
      res += y;
    }
    return res;
  } else {
    return x % y;
  }
}

// x raised to the power y via std::pow, converted to the type of x.
template <typename T, typename U>
T InnerScalarPow(T x, U y) {
  return static_cast<T>(std::pow(x, y));
}

// Equality test: both operands are converted to double and compared with an absolute tolerance of DBL_EPSILON.
// NOTE(review): int64 values above 2^53 are not exact in double, so distinct large values may compare equal.
template <typename T, typename U>
bool InnerScalarEq(T x, U y) {
  double error = static_cast<double>(x) - static_cast<double>(y);
  error = std::fabs(error);
  return error < DBL_EPSILON;
}

// x < y under the usual arithmetic conversions.
template <typename T, typename U>
bool InnerScalarLt(T x, U y) {
  return x < y;
}

// x > y under the usual arithmetic conversions.
template <typename T, typename U>
bool InnerScalarGt(T x, U y) {
  return x > y;
}

// Negation of InnerScalarEq (same tolerance-based comparison).
template <typename T, typename U>
bool InnerScalarNe(T x, U y) {
  return !InnerScalarEq(x, y);
}

// x <= y under the usual arithmetic conversions.
template <typename T, typename U>
bool InnerScalarLe(T x, U y) {
  return x <= y;
}

// x >= y under the usual arithmetic conversions.
template <typename T, typename U>
bool InnerScalarGe(T x, U y) {
  return x >= y;
}

// Defines `ValuePtr Scalar<op_t>(const ValuePtrList &list)` for a binary arithmetic operator.
// `list` must hold exactly two values. The operand type pair selects the C++ type InnerScalar<op_t> runs in
// and the type of the returned value. Mixed int/float pairs are promoted through the IntToFloat,
// LongToDouble, FloatToDouble and IntToLong helpers. Unlisted pairs are reported via MS_EXCEPTION(TypeError).
#define SCALAR_OP(op_t)                                                                                           \
  ValuePtr Scalar##op_t(const ValuePtrList &list) {                                                               \
    constexpr size_t kListInputSize = 2;                                                                          \
    if (list.size() != kListInputSize) {                                                                          \
      MS_EXCEPTION(NotSupportError) << "Input number of Scalar" << #op_t << " should be 2, but got "              \
                                    << list.size();                                                               \
    }                                                                                                             \
    const ValuePtr &x = list[0];                                                                                  \
    const ValuePtr &y = list[1];                                                                                  \
    MS_EXCEPTION_IF_NULL(x);                                                                                      \
    MS_EXCEPTION_IF_NULL(y);                                                                                      \
    if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) {                                                                 \
      double sum = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y));                                   \
      return MakeValue(sum);                                                                                      \
    }                                                                                                             \
    if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) {                                                                 \
      float sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y));                                      \
      return MakeValue(sum);                                                                                      \
    }                                                                                                             \
    if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) {                                                               \
      int sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y));                                            \
      return MakeValue(sum);                                                                                      \
    }                                                                                                             \
    if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) {                                                                \
      float sum = InnerScalar##op_t(IntToFloat(GetValue<int>(x)), GetValue<float>(y));                            \
      return MakeValue(sum);                                                                                      \
    }                                                                                                             \
    if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) {                                                                \
      float sum = InnerScalar##op_t(GetValue<float>(x), IntToFloat(GetValue<int>(y)));                            \
      return MakeValue(sum);                                                                                      \
    }                                                                                                             \
    if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) {                                                               \
      int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y));                                \
      return MakeValue(sum);                                                                                      \
    }                                                                                                             \
    if (x->isa<Int64Imm>() && y->isa<FP64Imm>()) {                                                                \
      double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), GetValue<double>(y));                    \
      return MakeValue(sum);                                                                                      \
    }                                                                                                             \
    if (x->isa<Int64Imm>() && y->isa<FP32Imm>()) {                                                                \
      double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), FloatToDouble(GetValue<float>(y)));      \
      return MakeValue(sum);                                                                                      \
    }                                                                                                             \
    if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) {                                                               \
      int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), IntToLong(GetValue<int>(y)));                         \
      return MakeValue(sum);                                                                                      \
    }                                                                                                             \
    if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) {                                                                \
      double sum = InnerScalar##op_t(FloatToDouble(GetValue<float>(x)), LongToDouble(GetValue<int64_t>(y)));      \
      return MakeValue(sum);                                                                                      \
    }                                                                                                             \
    if (x->isa<FP64Imm>() && y->isa<Int64Imm>()) {                                                                \
      double sum = InnerScalar##op_t(GetValue<double>(x), LongToDouble(GetValue<int64_t>(y)));                    \
      return MakeValue(sum);                                                                                      \
    }                                                                                                             \
    if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) {                                                               \
      int64_t sum = InnerScalar##op_t(IntToLong(GetValue<int>(x)), GetValue<int64_t>(y));                         \
      return MakeValue(sum);                                                                                      \
    }                                                                                                             \
    MS_EXCEPTION(TypeError) << "Unsupported input type for Scalar" << #op_t << ", type of x:" << x->type_name()   \
                            << ", value of x:" << x->ToString() << ", type of y:" << y->type_name()               \
                            << ", value of y:" << y->ToString();                                                  \
  }

SCALAR_OP(Add)
SCALAR_OP(Sub)
SCALAR_OP(Mul)
SCALAR_OP(Div)
SCALAR_OP(Mod)
SCALAR_OP(Pow)
SCALAR_OP(Floordiv)

namespace {
// Calls `fn` with the native C++ value held by `value` when it is an Int32Imm (int), Int64Imm (int64_t),
// FP32Imm (float) or FP64Imm (double), and returns true. Returns false, without calling `fn`, for any
// other kind of value.
template <typename Fn>
bool VisitNumericScalar(const ValuePtr &value, Fn &&fn) {
  if (value->isa<Int32Imm>()) {
    fn(GetValue<int>(value));
    return true;
  }
  if (value->isa<Int64Imm>()) {
    fn(GetValue<int64_t>(value));
    return true;
  }
  if (value->isa<FP32Imm>()) {
    fn(GetValue<float>(value));
    return true;
  }
  if (value->isa<FP64Imm>()) {
    fn(GetValue<double>(value));
    return true;
  }
  return false;
}
}  // namespace

// Defines `ValuePtr Scalar<op_t>(const ValuePtrList &list)` for a comparison operator returning a bool value.
// `list` must hold exactly two values. Every pairing of Int32Imm, Int64Imm, FP32Imm and FP64Imm operands is
// compared in its native C++ types via InnerScalar<op_t>. Any other operand kind is reported via
// MS_EXCEPTION(TypeError).
#define LOGIC_OP(op_t)                                                                                            \
  ValuePtr Scalar##op_t(const ValuePtrList &list) {                                                               \
    constexpr size_t kListInputSize = 2;                                                                          \
    if (list.size() != kListInputSize) {                                                                          \
      MS_EXCEPTION(NotSupportError) << "Input number of Scalar" << #op_t << " should be 2, but got "              \
                                    << list.size();                                                               \
    }                                                                                                             \
    const ValuePtr &x = list[0];                                                                                  \
    const ValuePtr &y = list[1];                                                                                  \
    MS_EXCEPTION_IF_NULL(x);                                                                                      \
    MS_EXCEPTION_IF_NULL(y);                                                                                      \
    bool result = false;                                                                                          \
    bool supported = false;                                                                                       \
    (void)VisitNumericScalar(x, [&y, &result, &supported](auto lhs) {                                             \
      supported = VisitNumericScalar(y, [lhs, &result](auto rhs) { result = InnerScalar##op_t(lhs, rhs); });      \
    });                                                                                                           \
    if (supported) {                                                                                              \
      return MakeValue(result);                                                                                   \
    }                                                                                                             \
    MS_EXCEPTION(TypeError) << "Unsupported input type for Scalar" << #op_t << ", type of x:" << x->type_name()   \
                            << ", value of x:" << x->ToString() << ", type of y:" << y->type_name()               \
                            << ", value of y:" << y->ToString();                                                  \
  }

LOGIC_OP(Eq)
LOGIC_OP(Lt)
LOGIC_OP(Gt)
LOGIC_OP(Ne)
LOGIC_OP(Le)
LOGIC_OP(Ge)

// Unary plus: returns the single input value unchanged.
ValuePtr ScalarUAdd(const ValuePtrList &list) {
  if (list.size() != 1) {
    MS_EXCEPTION(NotSupportError) << "Input number of ScalarUAdd should be 1, but got " << list.size();
  }
  ValuePtr x = list[0];
  MS_EXCEPTION_IF_NULL(x);
  return x;
}

// Unary minus for Int32Imm, Int64Imm, FP32Imm and FP64Imm inputs. Negating the minimum signed integer is
// reported via MS_EXCEPTION(ValueError) by the checked InnerScalarMul instead of overflowing.
ValuePtr ScalarUSub(const ValuePtrList &list) {
  if (list.size() != 1) {
    MS_EXCEPTION(NotSupportError) << "Input number of ScalarUSub should be 1, but got " << list.size();
  }
  const ValuePtr &x = list[0];
  MS_EXCEPTION_IF_NULL(x);
  if (x->isa<Int32Imm>()) {
    int32_t sum = InnerScalarMul(static_cast<int32_t>(-1), GetValue<int32_t>(x));
    return MakeValue(sum);
  }
  if (x->isa<Int64Imm>()) {
    int64_t sum = InnerScalarMul(static_cast<int64_t>(-1), GetValue<int64_t>(x));
    return MakeValue(sum);
  }
  if (x->isa<FP32Imm>()) {
    float sum = -1.0f * GetValue<float>(x);
    return MakeValue(sum);
  }
  if (x->isa<FP64Imm>()) {
    double sum = -1.0 * GetValue<double>(x);
    return MakeValue(sum);
  }
  MS_EXCEPTION(NotSupportError) << "Not support ScalarUSub [x:" << x->ToString() << "].";
}

// Natural logarithm of an FP64Imm or FP32Imm input; the FP32 result is narrowed back to float.
ValuePtr ScalarLog(const ValuePtrList &list) {
  if (list.size() != 1) {
    MS_EXCEPTION(NotSupportError) << "Input number of ScalarLog must be 1, but got " << list.size();
  }
  const ValuePtr &x = list[0];
  MS_EXCEPTION_IF_NULL(x);
  if (x->isa<FP64Imm>()) {
    double v = log(GetValue<double>(x));
    return MakeValue(v);
  }
  if (x->isa<FP32Imm>()) {
    auto v = static_cast<float>(log(GetValue<float>(x)));
    return MakeValue(v);
  }
  MS_EXCEPTION(NotSupportError) << "Not support ScalarLog [x:" << x->ToString() << "].";
}

// Logical not of the single input, after conversion with ValueToBool.
ValuePtr BoolNot(const ValuePtrList &list) {
  if (list.size() != 1) {
    MS_EXCEPTION(NotSupportError) << "Input number of BoolNot must be 1, but got " << list.size();
  }
  const ValuePtr &x = list[0];
  MS_EXCEPTION_IF_NULL(x);
  bool convert = false;
  if (ValueToBool(x, &convert)) {
    auto res = !convert;
    return MakeValue(res);
  }
  MS_EXCEPTION(NotSupportError) << "Not support BoolNot [x:" << x->ToString() << "].";
}

// Logical and of two inputs, each converted with ValueToBool.
ValuePtr BoolAnd(const ValuePtrList &list) {
  constexpr size_t kListInputSize = 2;
  if (list.size() != kListInputSize) {
    MS_EXCEPTION(NotSupportError) << "Input number of BoolAnd must be 2, but got " << list.size();
  }
  const ValuePtr &x = list[0];
  const ValuePtr &y = list[1];
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(y);
  bool x_b = false;
  bool y_b = false;
  if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
    auto res = x_b && y_b;
    return MakeValue(res);
  }
  // Closing "]." added so the message matches BoolOr / BoolEq.
  MS_EXCEPTION(NotSupportError) << "Not support [x:" << x->ToString() << "] BoolAnd [y:" << y->ToString() << "].";
}

// Logical or of two inputs, each converted with ValueToBool.
ValuePtr BoolOr(const ValuePtrList &list) {
  constexpr size_t kListInputSize = 2;
  if (list.size() != kListInputSize) {
    MS_EXCEPTION(NotSupportError) << "Input number of BoolOr must be 2, but got " << list.size();
  }
  const ValuePtr &x = list[0];
  const ValuePtr &y = list[1];
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(y);
  bool x_b = false;
  bool y_b = false;
  if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
    auto res = x_b || y_b;
    return MakeValue(res);
  }
  MS_EXCEPTION(NotSupportError) << "Not support [x:" << x->ToString() << "] BoolOr [y:" << y->ToString() << "].";
}

// Equality of two inputs after each is converted with ValueToBool.
ValuePtr BoolEq(const ValuePtrList &list) {
  constexpr size_t kListInputSize = 2;
  if (list.size() != kListInputSize) {
    MS_EXCEPTION(NotSupportError) << "Input number of BoolEq must be 2, but got " << list.size();
  }
  const ValuePtr &x = list[0];
  const ValuePtr &y = list[1];
  MS_EXCEPTION_IF_NULL(x);
  MS_EXCEPTION_IF_NULL(y);
  bool x_b = false;
  bool y_b = false;
  if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
    auto res = x_b == y_b;
    return MakeValue(res);
  }
  MS_EXCEPTION(NotSupportError) << "Not support [x:" << x->ToString() << "] BoolEq [y:" << y->ToString() << "].";
}
}  // namespace prim
}  // namespace mindspore