You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cc_implementations.cc 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "frontend/operator/cc_implementations.h"
  #include <algorithm>
  #include <cfloat>
  #include <cmath>
  #include <cstdint>
  #include <limits>
  #include <string>
  #include <type_traits>
  #include "utils/log_adapter.h"
  #include "utils/convert_utils.h"
  #include "utils/ms_utils.h"
  24. namespace mindspore {
  25. // namespace to support primitive operators definition
  26. namespace prim {
  // Scalar categories reported by InferType; kUnknown means none of the recognised types was found.
  enum class DataType { kInt = 0, kFloat = 1, kDouble = 2, kUnknown = 3 };
  28. // Whether has a T type data in AnyPtrList.
  29. template <class T>
  30. bool HasType(const AnyPtrList &list) {
  31. bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is<T>(); });
  32. return ret;
  33. }
  34. DataType InferType(const AnyPtrList &list) {
  35. if (HasType<double>(list)) {
  36. return DataType::kDouble;
  37. } else if (HasType<float>(list)) {
  38. return DataType::kFloat;
  39. } else if (HasType<int>(list)) {
  40. return DataType::kInt;
  41. }
  42. return DataType::kUnknown;
  43. }
  // Arithmetic operations understood by IsSignedIntOverflow.
  enum OpType { ADD = 0, SUB = 1, MUL = 2, DIV = 3, MOD = 4 };
  45. template <typename T>
  46. bool IsSignedIntOverflow(T x, T y, OpType opType) {
  47. auto max = std::numeric_limits<T>::max();
  48. auto min = std::numeric_limits<T>::min();
  49. if (opType == OpType::ADD) {
  50. return (y > 0 && (max - y) < x) || (y < 0 && (min - y) > x);
  51. }
  52. if (opType == OpType::SUB) {
  53. return (y < 0 && (max + y) < x) || (y > 0 && (min + y) > x);
  54. }
  55. if (opType == OpType::MUL) {
  56. return (x > 0 && y > 0 && (max / y) < x) || (x < 0 && y < 0 && (max / y) > x) ||
  57. (x > 0 && y < 0 && (min / y) < x) || (x < 0 && y > 0 && (min / y) > x);
  58. }
  59. if (opType == OpType::DIV || opType == OpType::MOD) {
  60. return x == min && static_cast<int64_t>(y) == -1;
  61. }
  62. MS_LOG(EXCEPTION) << "Unsupported operation type.";
  63. }
  64. template <typename T>
  65. T InnerScalarAdd(T x, T y) {
  66. if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::ADD)) {
  67. MS_LOG(EXCEPTION) << "Overflow of the sum of two signed number x: " << std::to_string(x)
  68. << ", y: " << std::to_string(y) << ".";
  69. }
  70. return x + y;
  71. }
  72. template <typename T>
  73. T InnerScalarSub(T x, T y) {
  74. if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::SUB)) {
  75. MS_LOG(EXCEPTION) << "Overflow of the sub of two signed number x: " << std::to_string(x)
  76. << ", y: " << std::to_string(y) << ".";
  77. }
  78. return x - y;
  79. }
  80. template <typename T>
  81. T InnerScalarMul(T x, T y) {
  82. if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MUL)) {
  83. MS_LOG(EXCEPTION) << "Overflow of the mul of two signed number x: " << std::to_string(x)
  84. << ", y: " << std::to_string(y) << ".";
  85. }
  86. return x * y;
  87. }
  88. template <typename T>
  89. float InnerScalarDiv(T x, T y) {
  90. if (y == 0) {
  91. MS_LOG(EXCEPTION) << "Divisor could not be zero";
  92. }
  93. if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::DIV)) {
  94. MS_LOG(EXCEPTION) << "Overflow of the div of two signed number x: " << std::to_string(x)
  95. << ", y: " << std::to_string(y) << ".";
  96. }
  97. return static_cast<float>(x) / static_cast<float>(y);
  98. }
  // Floor division (rounding toward negative infinity, like Python's `//`) of already-validated operands.
  // Integral overload: exact integer arithmetic. Requires y != 0 and not (min / -1).
  template <typename T>
  typename std::enable_if<std::is_integral<T>::value, T>::type FloorDivideUnchecked(T x, T y) {
    T quotient = x / y;  // C++ integer division truncates toward zero.
    // An inexact division with operands of opposite sign was rounded up by truncation; step down once.
    if (x % y != 0 && ((x < 0) != (y < 0))) {
      --quotient;
    }
    return quotient;
  }

  // Floating-point overload: divide in T's own precision, then round down.
  template <typename T>
  typename std::enable_if<!std::is_integral<T>::value, T>::type FloorDivideUnchecked(T x, T y) {
    return std::floor(x / y);
  }

  // Returns floor(x / y) as a T.
  // InnerScalarDiv is called only for its validation: it raises when y is zero or when signed integer
  // division overflows. The result is then computed without narrowing to float, so integers beyond
  // float's 24-bit mantissa and double operands keep full precision.
  template <typename T>
  T InnerScalarFloordiv(T x, T y) {
    (void)InnerScalarDiv(x, y);
    return FloorDivideUnchecked(x, y);
  }
  107. template <typename T>
  108. T InnerScalarMod(T x, T y) {
  109. if (y == 0) {
  110. MS_LOG(EXCEPTION) << "Could not mod to zero.";
  111. }
  112. if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MOD)) {
  113. MS_LOG(EXCEPTION) << "Overflow of the mod of two signed number x: " << std::to_string(x)
  114. << ", y: " << std::to_string(y) << ".";
  115. }
  116. if (std::is_integral<T>::value) {
  117. return static_cast<int>(x) % static_cast<int>(y);
  118. }
  119. return x - y * std::floor(x / y);
  120. }
  // Raises `base` to `exponent` with std::pow and converts the result back to the base's type.
  template <typename Base, typename Exp>
  Base InnerScalarPow(Base base, Exp exponent) {
    const auto raised = std::pow(base, exponent);
    return static_cast<Base>(raised);
  }
  // Equality test behind ScalarEq (and, negated, InnerScalarNe).
  // Two integral operands are compared exactly. Otherwise both are converted to double and count as equal
  // when they compare equal or differ by less than DBL_EPSILON. NaN is never equal to anything.
  template <typename T, typename U>
  bool InnerScalarEq(T x, U y) {
    if (std::is_integral<T>::value && std::is_integral<U>::value) {
      // Exact: converting large 64-bit integers to double rounds them and could report false equality.
      return x == y;
    }
    const double lhs = static_cast<double>(x);
    const double rhs = static_cast<double>(y);
    if (lhs == rhs) {
      // Must precede the epsilon test: inf - inf is NaN, which would make equal infinities compare unequal.
      return true;
    }
    return std::fabs(lhs - rhs) < DBL_EPSILON;
  }
  // Strict less-than between two scalars of possibly different arithmetic types (x < y).
  template <typename T, typename U>
  bool InnerScalarLt(T x, U y) {
    return y > x;
  }
  // Strict greater-than between two scalars of possibly different arithmetic types (x > y).
  template <typename T, typename U>
  bool InnerScalarGt(T x, U y) {
    return y < x;
  }
  // Inequality: true exactly when InnerScalarEq reports x and y as not equal.
  template <typename T, typename U>
  bool InnerScalarNe(T x, U y) {
    const bool equal = InnerScalarEq(x, y);
    return !equal;
  }
  // Less-than-or-equal between two scalars of possibly different arithmetic types (x <= y).
  template <typename T, typename U>
  bool InnerScalarLe(T x, U y) {
    return y >= x;
  }
  // Greater-than-or-equal between two scalars of possibly different arithmetic types (x >= y).
  template <typename T, typename U>
  bool InnerScalarGe(T x, U y) {
    return y <= x;
  }
  151. #define SCALAR_OP(op_t) \
  152. ValuePtr Scalar##op_t(const ValuePtrList &list) { \
  153. do { \
  154. if (list.size() < 2) { \
  155. MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \
  156. } \
  157. ValuePtr x = list[0]; \
  158. ValuePtr y = list[1]; \
  159. MS_EXCEPTION_IF_NULL(x); \
  160. MS_EXCEPTION_IF_NULL(y); \
  161. if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) { \
  162. double sum = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y)); \
  163. return MakeValue(sum); \
  164. } \
  165. if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) { \
  166. float sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y)); \
  167. return MakeValue(sum); \
  168. } \
  169. if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) { \
  170. int sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y)); \
  171. return MakeValue(sum); \
  172. } \
  173. if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) { \
  174. float sum = InnerScalar##op_t(IntToFloat(GetValue<int>(x)), GetValue<float>(y)); \
  175. return MakeValue(sum); \
  176. } \
  177. if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) { \
  178. float sum = InnerScalar##op_t(GetValue<float>(x), IntToFloat(GetValue<int>(y))); \
  179. return MakeValue(sum); \
  180. } \
  181. MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \
  182. << ", y: " << y->ToString(); \
  183. } while (0); \
  184. }
  185. SCALAR_OP(Add)
  186. SCALAR_OP(Sub)
  187. SCALAR_OP(Mul)
  188. SCALAR_OP(Div)
  189. SCALAR_OP(Mod)
  190. SCALAR_OP(Pow)
  191. SCALAR_OP(Floordiv)
  192. #define LOGIC_OP(op_t) \
  193. ValuePtr Scalar##op_t(const ValuePtrList &list) { \
  194. if (list.size() < 2) { \
  195. MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \
  196. } \
  197. ValuePtr x = list[0]; \
  198. ValuePtr y = list[1]; \
  199. MS_EXCEPTION_IF_NULL(x); \
  200. MS_EXCEPTION_IF_NULL(y); \
  201. if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) { \
  202. bool sum = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y)); \
  203. return MakeValue(sum); \
  204. } \
  205. if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) { \
  206. bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y)); \
  207. return MakeValue(sum); \
  208. } \
  209. if (x->isa<FP64Imm>() && y->isa<FP32Imm>()) { \
  210. bool sum = InnerScalar##op_t(GetValue<double>(x), GetValue<float>(y)); \
  211. return MakeValue(sum); \
  212. } \
  213. if (x->isa<FP32Imm>() && y->isa<FP64Imm>()) { \
  214. bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<double>(y)); \
  215. return MakeValue(sum); \
  216. } \
  217. if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) { \
  218. bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y)); \
  219. return MakeValue(sum); \
  220. } \
  221. if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) { \
  222. bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int>(y)); \
  223. return MakeValue(sum); \
  224. } \
  225. if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) { \
  226. bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<float>(y)); \
  227. return MakeValue(sum); \
  228. } \
  229. if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
  230. bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int>(y)); \
  231. return MakeValue(sum); \
  232. } \
  233. MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \
  234. << ", y: " << y->ToString() << "."; \
  235. }
  236. LOGIC_OP(Eq)
  237. LOGIC_OP(Lt)
  238. LOGIC_OP(Gt)
  239. LOGIC_OP(Ne)
  240. LOGIC_OP(Le)
  241. LOGIC_OP(Ge)
  242. ValuePtr ScalarUAdd(const ValuePtrList &list) {
  243. if (list.size() != 1) {
  244. MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size();
  245. }
  246. ValuePtr x = list[0];
  247. MS_EXCEPTION_IF_NULL(x);
  248. return x;
  249. }
  250. ValuePtr ScalarUSub(const ValuePtrList &list) {
  251. if (list.size() != 1) {
  252. MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size();
  253. }
  254. ValuePtr x = list[0];
  255. MS_EXCEPTION_IF_NULL(x);
  256. if (x->isa<Int32Imm>()) {
  257. int32_t sum = -1 * GetValue<int>(x);
  258. return MakeValue(sum);
  259. }
  260. if (x->isa<FP32Imm>()) {
  261. float sum = -1.0f * GetValue<float>(x);
  262. return MakeValue(sum);
  263. }
  264. MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << ".";
  265. }
  266. ValuePtr ScalarLog(const ValuePtrList &list) {
  267. if (list.empty()) {
  268. MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty.";
  269. }
  270. ValuePtr x = list[0];
  271. MS_EXCEPTION_IF_NULL(x);
  272. if (x->isa<FP64Imm>()) {
  273. double v = log(GetValue<double>(x));
  274. return MakeValue(v);
  275. }
  276. if (x->isa<FP32Imm>()) {
  277. auto v = static_cast<float>(log(GetValue<float>(x)));
  278. return MakeValue(v);
  279. }
  280. MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString();
  281. }
  282. ValuePtr BoolNot(const ValuePtrList &list) {
  283. if (list.empty()) {
  284. MS_LOG(EXCEPTION) << "value list of BoolNot is empty";
  285. }
  286. ValuePtr x = list[0];
  287. MS_EXCEPTION_IF_NULL(x);
  288. bool convert = false;
  289. if (ValueToBool(x, &convert)) {
  290. auto res = !convert;
  291. return MakeValue(res);
  292. }
  293. MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString();
  294. }
  295. ValuePtr BoolAnd(const ValuePtrList &list) {
  296. if (list.size() < 2) {
  297. MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2.";
  298. }
  299. ValuePtr x = list[0];
  300. ValuePtr y = list[1];
  301. MS_EXCEPTION_IF_NULL(x);
  302. MS_EXCEPTION_IF_NULL(y);
  303. bool x_b = false;
  304. bool y_b = false;
  305. if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
  306. auto res = x_b && y_b;
  307. return MakeValue(res);
  308. }
  309. MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << ".";
  310. }
  311. ValuePtr BoolOr(const ValuePtrList &list) {
  312. if (list.size() < 2) {
  313. MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2.";
  314. }
  315. ValuePtr x = list[0];
  316. ValuePtr y = list[1];
  317. MS_EXCEPTION_IF_NULL(x);
  318. MS_EXCEPTION_IF_NULL(y);
  319. bool x_b = false;
  320. bool y_b = false;
  321. if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
  322. auto res = x_b || y_b;
  323. return MakeValue(res);
  324. }
  325. MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << ".";
  326. }
  327. ValuePtr BoolEq(const ValuePtrList &list) {
  328. if (list.size() < 2) {
  329. MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2.";
  330. }
  331. ValuePtr x = list[0];
  332. ValuePtr y = list[1];
  333. MS_EXCEPTION_IF_NULL(x);
  334. MS_EXCEPTION_IF_NULL(y);
  335. bool x_b = false;
  336. bool y_b = false;
  337. if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
  338. auto res = x_b == y_b;
  339. return MakeValue(res);
  340. }
  341. MS_LOG(EXCEPTION) << "Unsported Value for BoolEq, x: " << x->ToString() << ".";
  342. }
  343. } // namespace prim
  344. } // namespace mindspore