You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cc_implementations.cc 23 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "frontend/operator/cc_implementations.h"
  #include <algorithm>
  #include <cfloat>
  #include <cmath>
  #include <cstdint>
  #include <limits>
  #include <string>
  #include <type_traits>
  #include "utils/log_adapter.h"
  #include "utils/convert_utils.h"
  #include "utils/ms_utils.h"
  24. namespace mindspore {
  25. // namespace to support primitive operators definition
  26. namespace prim {
  // Scalar categories that InferType can report; kUnknown is returned when none matches.
  enum class DataType {
    kInt,
    kInt64,
    kFloat,
    kDouble,
    kUnknown
  };
  28. // Whether has a T type data in AnyPtrList.
  29. template <class T>
  30. bool HasType(const AnyPtrList &list) {
  31. bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is<T>(); });
  32. return ret;
  33. }
  34. DataType InferType(const AnyPtrList &list) {
  35. if (HasType<double>(list)) {
  36. return DataType::kDouble;
  37. } else if (HasType<float>(list)) {
  38. return DataType::kFloat;
  39. } else if (HasType<int64_t>(list)) {
  40. return DataType::kInt64;
  41. } else if (HasType<int>(list)) {
  42. return DataType::kInt;
  43. }
  44. return DataType::kUnknown;
  45. }
  // Reports whether x + y would fall outside [min, max], without performing the addition.
  template <typename T>
  bool IsAddOverflow(const T &x, const T &y, const T &max, const T &min) {
    if (y > 0) {
      // Adding a positive value can only exceed the upper bound.
      return x > (max - y);
    }
    if (y < 0) {
      // Adding a negative value can only drop below the lower bound.
      return x < (min - y);
    }
    return false;
  }
  // Reports whether x - y would fall outside [min, max], without performing the subtraction.
  template <typename T>
  bool IsSubOverflow(const T &x, const T &y, const T &max, const T &min) {
    if (y < 0) {
      // Subtracting a negative value can only exceed the upper bound.
      return x > (max + y);
    }
    if (y > 0) {
      // Subtracting a positive value can only drop below the lower bound.
      return x < (min + y);
    }
    return false;
  }
  // Reports whether x * y would fall outside [min, max], without performing the multiplication.
  //
  // Every bound check divides by an operand chosen so that the division itself cannot
  // overflow. In particular the (x > 0, y < 0) case divides by x rather than y, because
  // `min / y` is undefined when y == -1.
  template <typename T>
  bool IsMulOverflow(const T &x, const T &y, const T &max, const T &min) {
    if (x > 0 && y > 0) {
      return x > max / y;
    }
    if (x < 0 && y < 0) {
      // Positive product; dividing by the negative y flips the inequality.
      return x < max / y;
    }
    if (x > 0 && y < 0) {
      // Negative product; x is positive, so `min / x` is always representable.
      return y < min / x;
    }
    if (x < 0 && y > 0) {
      return x < min / y;
    }
    // At least one operand is zero: the product is zero.
    return false;
  }
  // Reports whether x / y overflows, i.e. x equals `min` and y converts to -1.
  template <typename T>
  bool IsDivOverflow(const T &x, const T &y, const T &min) {
    if (!(x == min)) {
      return false;
    }
    return static_cast<int64_t>(y) == -1;
  }
  // Arithmetic operations for which IsSignedIntOverflow can test overflow.
  enum class OpType {
    ADD,
    SUB,
    MUL,
    DIV,
    MOD
  };
  64. template <typename T>
  65. bool IsSignedIntOverflow(T x, T y, OpType opType) {
  66. auto max = std::numeric_limits<T>::max();
  67. auto min = std::numeric_limits<T>::min();
  68. if (opType == OpType::ADD) {
  69. return IsAddOverflow<T>(x, y, max, min);
  70. }
  71. if (opType == OpType::SUB) {
  72. return IsSubOverflow<T>(x, y, max, min);
  73. }
  74. if (opType == OpType::MUL) {
  75. return IsMulOverflow<T>(x, y, max, min);
  76. }
  77. if (opType == OpType::DIV || opType == OpType::MOD) {
  78. return IsDivOverflow<T>(x, y, min);
  79. }
  80. MS_LOG(EXCEPTION) << "Unsupported operation type.";
  81. }
  82. template <typename T>
  83. T InnerScalarAdd(T x, T y) {
  84. if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::ADD)) {
  85. MS_LOG(EXCEPTION) << "Overflow of the sum of two signed number x: " << std::to_string(x)
  86. << ", y: " << std::to_string(y) << ".";
  87. }
  88. return x + y;
  89. }
  90. template <typename T>
  91. T InnerScalarSub(T x, T y) {
  92. if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::SUB)) {
  93. MS_LOG(EXCEPTION) << "Overflow of the sub of two signed number x: " << std::to_string(x)
  94. << ", y: " << std::to_string(y) << ".";
  95. }
  96. return x - y;
  97. }
  98. template <typename T>
  99. T InnerScalarMul(T x, T y) {
  100. if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MUL)) {
  101. MS_LOG(EXCEPTION) << "Overflow of the mul of two signed number x: " << std::to_string(x)
  102. << ", y: " << std::to_string(y) << ".";
  103. }
  104. return x * y;
  105. }
  106. template <typename T>
  107. float InnerScalarDiv(T x, T y) {
  108. if (y == 0) {
  109. MS_LOG(EXCEPTION) << "Divisor could not be zero";
  110. }
  111. if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::DIV)) {
  112. MS_LOG(EXCEPTION) << "Overflow of the div of two signed number x: " << std::to_string(x)
  113. << ", y: " << std::to_string(y) << ".";
  114. }
  115. return static_cast<float>(x) / static_cast<float>(y);
  116. }
  // Floor division for integral operands, computed exactly in integer arithmetic.
  // Callers must have rejected y == 0 and the (min, -1) overflow case beforehand.
  template <typename T>
  T InnerScalarFloordivImpl(T x, T y, std::true_type /* is_integral */) {
    T quotient = x / y;
    // Integer division truncates toward zero; step down by one when the division is
    // inexact and the operands have opposite signs, so the result rounds toward -infinity.
    if ((x % y != 0) && ((x < 0) != (y < 0))) {
      --quotient;
    }
    return quotient;
  }

  // Floor division for floating-point operands, computed in the operands' own precision.
  template <typename T>
  T InnerScalarFloordivImpl(T x, T y, std::false_type /* is_integral */) {
    return std::floor(x / y);
  }

  // Returns floor(x / y).
  //
  // Integral operands are divided exactly instead of flooring a single-precision quotient,
  // which returned wrong results once the operands no longer fit a float mantissa
  // (e.g. 2147483647 / 2). Floating-point operands are floored in type T.
  // A zero divisor and signed-integer overflow are reported by InnerScalarDiv.
  template <typename T>
  T InnerScalarFloordiv(T x, T y) {
    // Invoked only for its zero-divisor and overflow checks; the quotient is discarded.
    (void)InnerScalarDiv(x, y);
    return InnerScalarFloordivImpl(x, y, std::integral_constant<bool, std::is_integral<T>::value>());
  }
  125. template <typename T>
  126. T InnerScalarMod(T x, T y) {
  127. if (y == 0) {
  128. MS_LOG(EXCEPTION) << "Could not mod to zero.";
  129. }
  130. if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MOD)) {
  131. MS_LOG(EXCEPTION) << "Overflow of the mod of two signed number x: " << std::to_string(x)
  132. << ", y: " << std::to_string(y) << ".";
  133. }
  134. if (std::is_integral<T>::value) {
  135. return static_cast<int64_t>(x) % static_cast<int64_t>(y);
  136. }
  137. return x - y * std::floor(x / y);
  138. }
  // Raises x to the power y with std::pow and converts the result to T.
  template <typename T, typename U>
  T InnerScalarPow(T x, U y) {
    const auto raised = std::pow(x, y);
    return raised;
  }
  // Equality test for two scalars.
  //
  // - Two integral operands are compared exactly. Comparing them through double made distinct
  //   64-bit integers above 2^53 appear equal.
  // - Otherwise both operands are converted to double. They are equal when the doubles compare
  //   equal (which covers equal infinities, whose difference is NaN) or when they differ by
  //   less than DBL_EPSILON.
  template <typename T, typename U>
  bool InnerScalarEq(T x, U y) {
    if (std::is_integral<T>::value && std::is_integral<U>::value) {
      return x == y;
    }
    const double lhs = static_cast<double>(x);
    const double rhs = static_cast<double>(y);
    if (lhs == rhs) {
      return true;
    }
    return std::fabs(lhs - rhs) < DBL_EPSILON;
  }
  // Strict "less than" between two scalars of possibly different types.
  template <typename T, typename U>
  bool InnerScalarLt(T x, U y) {
    const bool is_less = (x < y);
    return is_less;
  }
  // Strict "greater than" between two scalars of possibly different types.
  template <typename T, typename U>
  bool InnerScalarGt(T x, U y) {
    const bool is_greater = (x > y);
    return is_greater;
  }
  // Inequality between two scalars, defined as the negation of InnerScalarEq.
  template <typename T, typename U>
  bool InnerScalarNe(T x, U y) {
    const bool is_equal = InnerScalarEq(x, y);
    return !is_equal;
  }
  // "Less than or equal" between two scalars of possibly different types.
  template <typename T, typename U>
  bool InnerScalarLe(T x, U y) {
    const bool is_less_or_equal = (x <= y);
    return is_less_or_equal;
  }
  // "Greater than or equal" between two scalars of possibly different types.
  template <typename T, typename U>
  bool InnerScalarGe(T x, U y) {
    const bool is_greater_or_equal = (x >= y);
    return is_greater_or_equal;
  }
  169. #define SCALAR_OP(op_t) \
  170. ValuePtr Scalar##op_t(const ValuePtrList &list) { \
  171. do { \
  172. if (list.size() < 2) { \
  173. MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \
  174. } \
  175. ValuePtr x = list[0]; \
  176. ValuePtr y = list[1]; \
  177. MS_EXCEPTION_IF_NULL(x); \
  178. MS_EXCEPTION_IF_NULL(y); \
  179. if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) { \
  180. double sum = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y)); \
  181. return MakeValue(sum); \
  182. } \
  183. if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) { \
  184. float sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y)); \
  185. return MakeValue(sum); \
  186. } \
  187. if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) { \
  188. int sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y)); \
  189. return MakeValue(sum); \
  190. } \
  191. if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) { \
  192. float sum = InnerScalar##op_t(IntToFloat(GetValue<int>(x)), GetValue<float>(y)); \
  193. return MakeValue(sum); \
  194. } \
  195. if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) { \
  196. float sum = InnerScalar##op_t(GetValue<float>(x), IntToFloat(GetValue<int>(y))); \
  197. return MakeValue(sum); \
  198. } \
  199. if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) { \
  200. int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y)); \
  201. return MakeValue(sum); \
  202. } \
  203. if (x->isa<Int64Imm>() && y->isa<FP64Imm>()) { \
  204. double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), GetValue<double>(y)); \
  205. return MakeValue(sum); \
  206. } \
  207. if (x->isa<Int64Imm>() && y->isa<FP32Imm>()) { \
  208. double sum = InnerScalar##op_t(LongToDouble(GetValue<int64_t>(x)), FloatToDouble(GetValue<float>(y))); \
  209. return MakeValue(sum); \
  210. } \
  211. if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
  212. int64_t sum = InnerScalar##op_t(GetValue<int64_t>(x), IntToLong(GetValue<int>(y))); \
  213. return MakeValue(sum); \
  214. } \
  215. if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) { \
  216. double sum = InnerScalar##op_t(FloatToDouble(GetValue<float>(x)), LongToDouble(GetValue<int64_t>(y))); \
  217. return MakeValue(sum); \
  218. } \
  219. if (x->isa<FP64Imm>() && y->isa<Int64Imm>()) { \
  220. double sum = InnerScalar##op_t(GetValue<double>(x), LongToDouble(GetValue<int64_t>(y))); \
  221. return MakeValue(sum); \
  222. } \
  223. if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) { \
  224. int64_t sum = InnerScalar##op_t(IntToLong(GetValue<int>(x)), GetValue<int64_t>(y)); \
  225. return MakeValue(sum); \
  226. } \
  227. MS_LOG(EXCEPTION) << "Unsupported input type for Scalar" << #op_t << ", type of x:" << x->type_name() \
  228. << ", value of x:" << x->ToString() << ", type of y:" << y->type_name() \
  229. << ", value of y:" << y->ToString(); \
  230. } while (0); \
  231. }
  232. SCALAR_OP(Add)
  233. SCALAR_OP(Sub)
  234. SCALAR_OP(Mul)
  235. SCALAR_OP(Div)
  236. SCALAR_OP(Mod)
  237. SCALAR_OP(Pow)
  238. SCALAR_OP(Floordiv)
  239. #define LOGIC_OP(op_t) \
  240. ValuePtr Scalar##op_t(const ValuePtrList &list) { \
  241. if (list.size() < 2) { \
  242. MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \
  243. } \
  244. ValuePtr x = list[0]; \
  245. ValuePtr y = list[1]; \
  246. MS_EXCEPTION_IF_NULL(x); \
  247. MS_EXCEPTION_IF_NULL(y); \
  248. if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) { \
  249. bool sum = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y)); \
  250. return MakeValue(sum); \
  251. } \
  252. if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) { \
  253. bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y)); \
  254. return MakeValue(sum); \
  255. } \
  256. if (x->isa<FP64Imm>() && y->isa<FP32Imm>()) { \
  257. bool sum = InnerScalar##op_t(GetValue<double>(x), GetValue<float>(y)); \
  258. return MakeValue(sum); \
  259. } \
  260. if (x->isa<FP32Imm>() && y->isa<FP64Imm>()) { \
  261. bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<double>(y)); \
  262. return MakeValue(sum); \
  263. } \
  264. if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) { \
  265. bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y)); \
  266. return MakeValue(sum); \
  267. } \
  268. if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) { \
  269. bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int>(y)); \
  270. return MakeValue(sum); \
  271. } \
  272. if (x->isa<FP32Imm>() && y->isa<Int64Imm>()) { \
  273. bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int64_t>(y)); \
  274. return MakeValue(sum); \
  275. } \
  276. if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) { \
  277. bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<float>(y)); \
  278. return MakeValue(sum); \
  279. } \
  280. if (x->isa<Int64Imm>() && y->isa<FP32Imm>()) { \
  281. bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<float>(y)); \
  282. return MakeValue(sum); \
  283. } \
  284. if (x->isa<Int64Imm>() && y->isa<Int64Imm>()) { \
  285. bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int64_t>(y)); \
  286. return MakeValue(sum); \
  287. } \
  288. if (x->isa<FP64Imm>() && y->isa<Int64Imm>()) { \
  289. bool sum = InnerScalar##op_t(GetValue<double>(x), GetValue<int64_t>(y)); \
  290. return MakeValue(sum); \
  291. } \
  292. if (x->isa<Int64Imm>() && y->isa<FP64Imm>()) { \
  293. bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<double>(y)); \
  294. return MakeValue(sum); \
  295. } \
  296. if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
  297. bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int>(y)); \
  298. return MakeValue(sum); \
  299. } \
  300. if (x->isa<Int32Imm>() && y->isa<Int64Imm>()) { \
  301. bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int64_t>(y)); \
  302. return MakeValue(sum); \
  303. } \
  304. MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \
  305. << ", y: " << y->ToString() << "."; \
  306. }
  307. LOGIC_OP(Eq)
  308. LOGIC_OP(Lt)
  309. LOGIC_OP(Gt)
  310. LOGIC_OP(Ne)
  311. LOGIC_OP(Le)
  312. LOGIC_OP(Ge)
  313. ValuePtr ScalarUAdd(const ValuePtrList &list) {
  314. if (list.size() != 1) {
  315. MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size();
  316. }
  317. ValuePtr x = list[0];
  318. MS_EXCEPTION_IF_NULL(x);
  319. return x;
  320. }
  321. ValuePtr ScalarUSub(const ValuePtrList &list) {
  322. if (list.size() != 1) {
  323. MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size();
  324. }
  325. ValuePtr x = list[0];
  326. MS_EXCEPTION_IF_NULL(x);
  327. if (x->isa<Int32Imm>()) {
  328. int32_t sum = -1 * GetValue<int32_t>(x);
  329. return MakeValue(sum);
  330. }
  331. if (x->isa<Int64Imm>()) {
  332. int64_t sum = -1 * GetValue<int64_t>(x);
  333. return MakeValue(sum);
  334. }
  335. if (x->isa<FP32Imm>()) {
  336. float sum = -1.0f * GetValue<float>(x);
  337. return MakeValue(sum);
  338. }
  339. MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << ".";
  340. }
  341. ValuePtr ScalarLog(const ValuePtrList &list) {
  342. if (list.empty()) {
  343. MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty.";
  344. }
  345. ValuePtr x = list[0];
  346. MS_EXCEPTION_IF_NULL(x);
  347. if (x->isa<FP64Imm>()) {
  348. double v = log(GetValue<double>(x));
  349. return MakeValue(v);
  350. }
  351. if (x->isa<FP32Imm>()) {
  352. auto v = static_cast<float>(log(GetValue<float>(x)));
  353. return MakeValue(v);
  354. }
  355. MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString();
  356. }
  357. ValuePtr BoolNot(const ValuePtrList &list) {
  358. if (list.empty()) {
  359. MS_LOG(EXCEPTION) << "value list of BoolNot is empty";
  360. }
  361. ValuePtr x = list[0];
  362. MS_EXCEPTION_IF_NULL(x);
  363. bool convert = false;
  364. if (ValueToBool(x, &convert)) {
  365. auto res = !convert;
  366. return MakeValue(res);
  367. }
  368. MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString();
  369. }
  370. ValuePtr BoolAnd(const ValuePtrList &list) {
  371. if (list.size() < 2) {
  372. MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2.";
  373. }
  374. ValuePtr x = list[0];
  375. ValuePtr y = list[1];
  376. MS_EXCEPTION_IF_NULL(x);
  377. MS_EXCEPTION_IF_NULL(y);
  378. bool x_b = false;
  379. bool y_b = false;
  380. if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
  381. auto res = x_b && y_b;
  382. return MakeValue(res);
  383. }
  384. MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << ".";
  385. }
  386. ValuePtr BoolOr(const ValuePtrList &list) {
  387. if (list.size() < 2) {
  388. MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2.";
  389. }
  390. ValuePtr x = list[0];
  391. ValuePtr y = list[1];
  392. MS_EXCEPTION_IF_NULL(x);
  393. MS_EXCEPTION_IF_NULL(y);
  394. bool x_b = false;
  395. bool y_b = false;
  396. if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
  397. auto res = x_b || y_b;
  398. return MakeValue(res);
  399. }
  400. MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << ".";
  401. }
  402. ValuePtr BoolEq(const ValuePtrList &list) {
  403. if (list.size() < 2) {
  404. MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2.";
  405. }
  406. ValuePtr x = list[0];
  407. ValuePtr y = list[1];
  408. MS_EXCEPTION_IF_NULL(x);
  409. MS_EXCEPTION_IF_NULL(y);
  410. bool x_b = false;
  411. bool y_b = false;
  412. if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
  413. auto res = x_b == y_b;
  414. return MakeValue(res);
  415. }
  416. MS_LOG(EXCEPTION) << "Unsported Value for BoolEq, x: " << x->ToString() << ".";
  417. }
  418. } // namespace prim
  419. } // namespace mindspore