You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cc_implementations.cc 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. /**
  2. * Copyright 2019-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "frontend/operator/cc_implementations.h"
  17. #include <limits>
  18. #include <algorithm>
  19. #include <cmath>
  20. #include <cfloat>
  21. #include "utils/log_adapter.h"
  22. #include "utils/convert_utils.h"
  23. #include "utils/ms_utils.h"
  24. namespace mindspore {
  25. // namespace to support primitive operators definition
  26. namespace prim {
  // Scalar categories reported by InferType for the contents of an AnyPtrList.
  enum class DataType {
    kInt,
    kInt64,
    kFloat,
    kDouble,
    kUnknown
  };
  28. // Whether has a T type data in AnyPtrList.
  29. template <class T>
  30. bool HasType(const AnyPtrList &list) {
  31. bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is<T>(); });
  32. return ret;
  33. }
  34. DataType InferType(const AnyPtrList &list) {
  35. if (HasType<double>(list)) {
  36. return DataType::kDouble;
  37. } else if (HasType<float>(list)) {
  38. return DataType::kFloat;
  39. } else if (HasType<int64_t>(list)) {
  40. return DataType::kInt64;
  41. } else if (HasType<int>(list)) {
  42. return DataType::kInt;
  43. }
  44. return DataType::kUnknown;
  45. }
  46. template <typename T>
  47. T InnerScalarAdd(T x, T y) {
  48. if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
  49. T res;
  50. if (__builtin_add_overflow(x, y, &res)) {
  51. MS_EXCEPTION(ValueError) << "Overflow of the sum of two signed number x: " << std::to_string(x)
  52. << ", y: " << std::to_string(y) << ".";
  53. }
  54. return res;
  55. }
  56. return x + y;
  57. }
  58. template <typename T>
  59. T InnerScalarSub(T x, T y) {
  60. if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
  61. T res;
  62. if (__builtin_sub_overflow(x, y, &res)) {
  63. MS_EXCEPTION(ValueError) << "Overflow of the sub of two signed number x: " << std::to_string(x)
  64. << ", y: " << std::to_string(y) << ".";
  65. }
  66. return res;
  67. }
  68. return x - y;
  69. }
  70. template <typename T>
  71. T InnerScalarMul(T x, T y) {
  72. if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
  73. T res;
  74. if (__builtin_mul_overflow(x, y, &res)) {
  75. MS_EXCEPTION(ValueError) << "Overflow of the mul of two signed number x: " << std::to_string(x)
  76. << ", y: " << std::to_string(y) << ".";
  77. }
  78. return res;
  79. }
  80. return x * y;
  81. }
  82. template <typename T>
  83. float InnerScalarDiv(T x, T y) {
  84. if (y == 0) {
  85. MS_EXCEPTION(ValueError) << "The divisor could not be zero. But the divisor is zero now.";
  86. }
  87. if constexpr (std::is_integral<T>::value && std::is_signed<T>::value) {
  88. if (x == std::numeric_limits<T>::min() && static_cast<int64_t>(y) == -1) {
  89. MS_EXCEPTION(ValueError) << "Overflow of the div of two signed number x: " << std::to_string(x)
  90. << ", y: " << std::to_string(y) << ".";
  91. }
  92. }
  93. return static_cast<float>(x) / static_cast<float>(y);
  94. }
  // Floor division: returns the largest value of type T that is not greater
  // than the exact quotient x / y.
  //
  // Argument validation (zero divisor, and minimum-value / -1 for signed
  // integers) is delegated to InnerScalarDiv, which reports a ValueError through
  // MS_EXCEPTION. Its float quotient is deliberately NOT used for the result:
  // single precision cannot represent every int/int64_t/double operand, which
  // previously produced wrong floors for large operands (for example any
  // integer above 2^24 divided by 1).
  template <typename T>
  T InnerScalarFloordiv(T x, T y) {
    // Called only for its checks; the returned float is discarded.
    (void)InnerScalarDiv(x, y);
    if constexpr (std::is_integral<T>::value) {
      // y != 0 and (x, y) != (min, -1) are guaranteed here, so x / y is well defined.
      T quotient = static_cast<T>(x / y);
      if constexpr (std::is_signed<T>::value) {
        // C++ truncates toward zero; step down by one when the exact quotient
        // is negative and not an integer.
        const bool has_remainder = (x % y) != 0;
        const bool signs_differ = (x < 0) != (y < 0);
        if (has_remainder && signs_differ) {
          --quotient;
        }
      }
      return quotient;
    } else {
      // Floating-point operands: divide and floor in T's own precision.
      return std::floor(x / y);
    }
  }
  100. template <typename T>
  101. T InnerScalarMod(T x, T y) {
  102. if (y == 0) {
  103. MS_EXCEPTION(ValueError) << "Cannot perform modulo operation on zero.";
  104. }
  105. if constexpr (!std::is_integral<T>::value) {
  106. return x - y * std::floor(x / y);
  107. }
  108. if constexpr (std::is_signed<T>::value) {
  109. if (x == std::numeric_limits<T>::min() && static_cast<int64_t>(y) == -1) {
  110. MS_EXCEPTION(ValueError) << "Overflow of the mod of two signed number x: " << std::to_string(x)
  111. << ", y: " << std::to_string(y) << ".";
  112. }
  113. }
  114. return static_cast<int64_t>(x) % static_cast<int64_t>(y);
  115. }
  // Raises x to the power y with std::pow and converts the result back to T.
  template <typename T, typename U>
  T InnerScalarPow(T x, U y) {
    const auto raised = std::pow(x, y);
    return static_cast<T>(raised);
  }
  // Equality test for two scalars.
  //
  // When both operands are integral and share the same signedness they are
  // compared exactly with `==`. Routing them through double (as was done before)
  // loses bits above 2^53, so distinct int64_t values could compare equal.
  //
  // Every other combination is compared as doubles: the operands are equal when
  // the absolute difference is below DBL_EPSILON.
  template <typename T, typename U>
  bool InnerScalarEq(T x, U y) {
    if constexpr (std::is_integral<T>::value && std::is_integral<U>::value &&
                  (std::is_signed<T>::value == std::is_signed<U>::value)) {
      return x == y;
    } else {
      const double difference = std::fabs(static_cast<double>(x) - static_cast<double>(y));
      return difference < DBL_EPSILON;
    }
  }
  // Returns true when x is strictly less than y (built-in `<` on the operands).
  template <typename T, typename U>
  bool InnerScalarLt(T x, U y) {
    const bool is_less = x < y;
    return is_less;
  }
  // Returns true when x is strictly greater than y (built-in `>` on the operands).
  template <typename T, typename U>
  bool InnerScalarGt(T x, U y) {
    const bool is_greater = x > y;
    return is_greater;
  }
  // Inequality test: the logical negation of InnerScalarEq(x, y).
  template <typename T, typename U>
  bool InnerScalarNe(T x, U y) {
    const bool is_equal = InnerScalarEq(x, y);
    return !is_equal;
  }
  // Returns true when x is less than or equal to y (built-in `<=` on the operands).
  template <typename T, typename U>
  bool InnerScalarLe(T x, U y) {
    const bool is_less_or_equal = x <= y;
    return is_less_or_equal;
  }
  // Returns true when x is greater than or equal to y (built-in `>=` on the operands).
  template <typename T, typename U>
  bool InnerScalarGe(T x, U y) {
    const bool is_greater_or_equal = x >= y;
    return is_greater_or_equal;
  }
  146. #define SCALAR_OP(op_t) \
  147. ValuePtr Scalar##op_t(const ValuePtrList &list) { \
  148. constexpr size_t kListInputSize = 2; \
  149. if (list.size() != kListInputSize) { \
  150. MS_EXCEPTION(NotSupportError) << "Input number of Scalar" << #op_t << " should be 2, but got " << list.size(); \
  151. } \
  152. const ValuePtr &x = list[0]; \
  153. const ValuePtr &y = list[1]; \
  154. MS_EXCEPTION_IF_NULL(x); \
  155. MS_EXCEPTION_IF_NULL(y); \
  156. if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) { \
  157. double sum = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y)); \
  158. return MakeValue(sum); \
  159. } \
  160. if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) { \
  161. float sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y)); \
  162. return MakeValue(sum); \
  163. } \
  164. if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) { \
  165. int sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y)); \
  166. return MakeValue(sum); \
  167. } \
  168. if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) { \
  169. float sum = InnerScalar##op_t(IntToFloat(GetValue<int>(x)), GetValue<float>(y)); \
  170. return MakeValue(sum); \
  171. } \
  172. if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) { \
  173. float sum = InnerScalar##op_t(GetValue<float>(x), IntToFloat(GetValue<int>(y))); \
  174. return MakeValue(sum); \
  175. } \
  176. if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) { \
  177. int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y)); \
  178. return MakeValue(sum); \
  179. } \
  180. if (x->isa<Int64Imm>() && y->isa<FP64Imm>()) { \
  181. double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), GetValue<double>(y)); \
  182. return MakeValue(sum); \
  183. } \
  184. if (x->isa<Int64Imm>() && y->isa<FP32Imm>()) { \
  185. double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), FloatToDouble(GetValue<float>(y))); \
  186. return MakeValue(sum); \
  187. } \
  188. if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
  189. int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), IntToLong(GetValue<int>(y))); \
  190. return MakeValue(sum); \
  191. } \
  192. if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) { \
  193. double sum = InnerScalar##op_t(FloatToDouble(GetValue<float>(x)), LongToDouble(GetValue<int64_t>(y))); \
  194. return MakeValue(sum); \
  195. } \
  196. if (x->isa<FP64Imm>() && y->isa<Int64Imm>()) { \
  197. double sum = InnerScalar##op_t(GetValue<double>(x), LongToDouble(GetValue<int64_t>(y))); \
  198. return MakeValue(sum); \
  199. } \
  200. if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) { \
  201. int64_t sum = InnerScalar##op_t(IntToLong(GetValue<int>(x)), GetValue<int64_t>(y)); \
  202. return MakeValue(sum); \
  203. } \
  204. MS_EXCEPTION(TypeError) << "Unsupported input type for Scalar" << #op_t << ", type of x:" << x->type_name() \
  205. << ", value of x:" << x->ToString() << ", type of y:" << y->type_name() \
  206. << ", value of y:" << y->ToString(); \
  207. }
  208. SCALAR_OP(Add)
  209. SCALAR_OP(Sub)
  210. SCALAR_OP(Mul)
  211. SCALAR_OP(Div)
  212. SCALAR_OP(Mod)
  213. SCALAR_OP(Pow)
  214. SCALAR_OP(Floordiv)
  215. #define LOGIC_OP(op_t) \
  216. ValuePtr Scalar##op_t(const ValuePtrList &list) { \
  217. constexpr size_t kListInputSize = 2; \
  218. if (list.size() != kListInputSize) { \
  219. MS_EXCEPTION(NotSupportError) << "Input number of Scalar" << #op_t << " should be 2, but got " << list.size(); \
  220. } \
  221. const ValuePtr &x = list[0]; \
  222. const ValuePtr &y = list[1]; \
  223. MS_EXCEPTION_IF_NULL(x); \
  224. MS_EXCEPTION_IF_NULL(y); \
  225. if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) { \
  226. bool sum = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y)); \
  227. return MakeValue(sum); \
  228. } \
  229. if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) { \
  230. bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y)); \
  231. return MakeValue(sum); \
  232. } \
  233. if (x->isa<FP64Imm>() && y->isa<FP32Imm>()) { \
  234. bool sum = InnerScalar##op_t(GetValue<double>(x), GetValue<float>(y)); \
  235. return MakeValue(sum); \
  236. } \
  237. if (x->isa<FP32Imm>() && y->isa<FP64Imm>()) { \
  238. bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<double>(y)); \
  239. return MakeValue(sum); \
  240. } \
  241. if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) { \
  242. bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y)); \
  243. return MakeValue(sum); \
  244. } \
  245. if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) { \
  246. bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int>(y)); \
  247. return MakeValue(sum); \
  248. } \
  249. if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) { \
  250. bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int64_t>(y)); \
  251. return MakeValue(sum); \
  252. } \
  253. if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) { \
  254. bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<float>(y)); \
  255. return MakeValue(sum); \
  256. } \
  257. if (x->isa<Int64Imm>() && y->isa<FP32Imm>()) { \
  258. bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<float>(y)); \
  259. return MakeValue(sum); \
  260. } \
  261. if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) { \
  262. bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y)); \
  263. return MakeValue(sum); \
  264. } \
  265. if (x->isa<FP64Imm>() && y->isa<Int64Imm>()) { \
  266. bool sum = InnerScalar##op_t(GetValue<double>(x), GetValue<int64_t>(y)); \
  267. return MakeValue(sum); \
  268. } \
  269. if (x->isa<Int64Imm>() && y->isa<FP64Imm>()) { \
  270. bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<double>(y)); \
  271. return MakeValue(sum); \
  272. } \
  273. if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
  274. bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int>(y)); \
  275. return MakeValue(sum); \
  276. } \
  277. if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) { \
  278. bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int64_t>(y)); \
  279. return MakeValue(sum); \
  280. } \
  281. MS_EXCEPTION(TypeError) << "Unsupported input type for Scalar" << #op_t << ", type of x:" << x->type_name() \
  282. << ", value of x:" << x->ToString() << ", type of y:" << y->type_name() \
  283. << ", value of y:" << y->ToString(); \
  284. }
  285. LOGIC_OP(Eq)
  286. LOGIC_OP(Lt)
  287. LOGIC_OP(Gt)
  288. LOGIC_OP(Ne)
  289. LOGIC_OP(Le)
  290. LOGIC_OP(Ge)
  291. ValuePtr ScalarUAdd(const ValuePtrList &list) {
  292. if (list.size() != 1) {
  293. MS_EXCEPTION(NotSupportError) << "Input number of ScalarUAdd should be 1, but got " << list.size() << ".";
  294. }
  295. ValuePtr x = list[0];
  296. MS_EXCEPTION_IF_NULL(x);
  297. return x;
  298. }
  299. ValuePtr ScalarUSub(const ValuePtrList &list) {
  300. if (list.size() != 1) {
  301. MS_EXCEPTION(NotSupportError) << "Input number of ScalarUSub should be 1, but got " << list.size() << ".";
  302. }
  303. ValuePtr x = list[0];
  304. MS_EXCEPTION_IF_NULL(x);
  305. if (x->isa<Int32Imm>()) {
  306. int32_t sum = -1 * GetValue<int32_t>(x);
  307. return MakeValue(sum);
  308. }
  309. if (x->isa<Int64Imm>()) {
  310. int64_t sum = -1 * GetValue<int64_t>(x);
  311. return MakeValue(sum);
  312. }
  313. if (x->isa<FP32Imm>()) {
  314. float sum = -1.0f * GetValue<float>(x);
  315. return MakeValue(sum);
  316. }
  317. MS_EXCEPTION(NotSupportError) << "Not support ScalarUSub [x:" << x->ToString() << "].";
  318. }
  319. ValuePtr ScalarLog(const ValuePtrList &list) {
  320. if (list.size() != 1) {
  321. MS_EXCEPTION(NotSupportError) << "Input number of ScalarLog must be 1, but got " << list.size() << ".";
  322. }
  323. ValuePtr x = list[0];
  324. MS_EXCEPTION_IF_NULL(x);
  325. if (x->isa<FP64Imm>()) {
  326. double v = log(GetValue<double>(x));
  327. return MakeValue(v);
  328. }
  329. if (x->isa<FP32Imm>()) {
  330. auto v = static_cast<float>(log(GetValue<float>(x)));
  331. return MakeValue(v);
  332. }
  333. MS_EXCEPTION(NotSupportError) << "Not support ScalarLog [x:" << x->ToString() << "].";
  334. }
  335. ValuePtr BoolNot(const ValuePtrList &list) {
  336. if (list.size() != 1) {
  337. MS_EXCEPTION(NotSupportError) << "Input number of BoolNot must be 1, but got " << list.size() << ".";
  338. }
  339. ValuePtr x = list[0];
  340. MS_EXCEPTION_IF_NULL(x);
  341. bool convert = false;
  342. if (ValueToBool(x, &convert)) {
  343. auto res = !convert;
  344. return MakeValue(res);
  345. }
  346. MS_EXCEPTION(NotSupportError) << "Not support BoolNot [x:" << x->ToString() << "].";
  347. }
  348. ValuePtr BoolAnd(const ValuePtrList &list) {
  349. constexpr size_t kListInputSize = 2;
  350. if (list.size() != kListInputSize) {
  351. MS_EXCEPTION(NotSupportError) << "Input number of BoolAnd must be 2, but got " << list.size() << ".";
  352. }
  353. ValuePtr x = list[0];
  354. ValuePtr y = list[1];
  355. MS_EXCEPTION_IF_NULL(x);
  356. MS_EXCEPTION_IF_NULL(y);
  357. bool x_b = false;
  358. bool y_b = false;
  359. if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
  360. auto res = x_b && y_b;
  361. return MakeValue(res);
  362. }
  363. MS_EXCEPTION(NotSupportError) << "Not support [x:" << x->ToString() << "] BoolAnd [y:" << y->ToString() << "].";
  364. }
  365. ValuePtr BoolOr(const ValuePtrList &list) {
  366. constexpr size_t kListInputSize = 2;
  367. if (list.size() != kListInputSize) {
  368. MS_EXCEPTION(NotSupportError) << "Input number of BoolOr must be 2, but got " << list.size() << ".";
  369. }
  370. ValuePtr x = list[0];
  371. ValuePtr y = list[1];
  372. MS_EXCEPTION_IF_NULL(x);
  373. MS_EXCEPTION_IF_NULL(y);
  374. bool x_b = false;
  375. bool y_b = false;
  376. if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
  377. auto res = x_b || y_b;
  378. return MakeValue(res);
  379. }
  380. MS_EXCEPTION(NotSupportError) << "Not support [x:" << x->ToString() << "] BoolOr [y:" << y->ToString() << "].";
  381. }
  382. ValuePtr BoolEq(const ValuePtrList &list) {
  383. constexpr size_t kListInputSize = 2;
  384. if (list.size() != kListInputSize) {
  385. MS_EXCEPTION(NotSupportError) << "Input number of BoolEq must be 2, but got " << list.size() << ".";
  386. }
  387. ValuePtr x = list[0];
  388. ValuePtr y = list[1];
  389. MS_EXCEPTION_IF_NULL(x);
  390. MS_EXCEPTION_IF_NULL(y);
  391. bool x_b = false;
  392. bool y_b = false;
  393. if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
  394. auto res = x_b == y_b;
  395. return MakeValue(res);
  396. }
  397. MS_EXCEPTION(NotSupportError) << "Not support [x:" << x->ToString() << "] BoolEq [y:" << y->ToString() << "].";
  398. }
  399. } // namespace prim
  400. } // namespace mindspore