You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cc_implementations.cc 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "operator/cc_implementations.h"
  17. #include <cassert>
  18. #include <limits>
  19. #include <algorithm>
  20. #include <cmath>
  21. #include <cfloat>
  22. #include "utils/misc.h"
  23. #include "utils/log_adapter.h"
  24. #include "utils/convert_utils.h"
  25. #include "common/utils.h"
  26. namespace mindspore {
  27. // namespace to support primitive operators definition
  28. namespace prim {
  // Scalar categories reported by InferType.
  enum class DataType {
    kInt,
    kFloat,
    kDouble,
    kUnknown,  // returned when no int, float or double entry is present
  };
  30. // Whether has a T type data in AnyPtrList.
  31. template <class T>
  32. bool HasType(const AnyPtrList &list) {
  33. bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is<T>(); });
  34. return ret;
  35. }
  36. DataType InferType(const AnyPtrList &list) {
  37. if (HasType<double>(list)) {
  38. return DataType::kDouble;
  39. } else if (HasType<float>(list)) {
  40. return DataType::kFloat;
  41. } else if (HasType<int>(list)) {
  42. return DataType::kInt;
  43. }
  44. return DataType::kUnknown;
  45. }
  // Arithmetic operation selector passed to IsSignedIntOverflow.
  enum OpType {
    ADD,
    SUB,
    MUL,
    DIV,
    MOD,
  };
  47. template <typename T>
  48. bool IsSignedIntOverflow(T x, T y, OpType opType) {
  49. auto max = std::numeric_limits<T>::max();
  50. auto min = std::numeric_limits<T>::min();
  51. if (opType == OpType::ADD) {
  52. return (y > 0 && (max - y) < x) || (y < 0 && (min - y) > x);
  53. }
  54. if (opType == OpType::SUB) {
  55. return (y < 0 && (max + y) < x) || (y > 0 && (min + y) > x);
  56. }
  57. if (opType == OpType::MUL) {
  58. return (x > 0 && y > 0 && (max / y) < x) || (x < 0 && y < 0 && (max / y) > x) ||
  59. (x > 0 && y < 0 && (min / y) < x) || (x < 0 && y > 0 && (min / y) > x);
  60. }
  61. if (opType == OpType::DIV || opType == OpType::MOD) {
  62. return x == min && static_cast<int64_t>(y) == -1;
  63. }
  64. MS_LOG(EXCEPTION) << "Unsupported operation type.";
  65. }
  66. template <typename T>
  67. T InnerScalarAdd(T x, T y) {
  68. if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::ADD)) {
  69. MS_LOG(EXCEPTION) << "Overflow of the sum of two signed number x: " << std::to_string(x)
  70. << ", y: " << std::to_string(y) << ".";
  71. }
  72. return x + y;
  73. }
  74. template <typename T>
  75. T InnerScalarSub(T x, T y) {
  76. if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::SUB)) {
  77. MS_LOG(EXCEPTION) << "Overflow of the sub of two signed number x: " << std::to_string(x)
  78. << ", y: " << std::to_string(y) << ".";
  79. }
  80. return x - y;
  81. }
  82. template <typename T>
  83. T InnerScalarMul(T x, T y) {
  84. if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MUL)) {
  85. MS_LOG(EXCEPTION) << "Overflow of the mul of two signed number x: " << std::to_string(x)
  86. << ", y: " << std::to_string(y) << ".";
  87. }
  88. return x * y;
  89. }
  90. template <typename T>
  91. float InnerScalarDiv(T x, T y) {
  92. if (y == 0) {
  93. MS_LOG(EXCEPTION) << "Divisor could not be zero";
  94. }
  95. if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::DIV)) {
  96. MS_LOG(EXCEPTION) << "Overflow of the div of two signed number x: " << std::to_string(x)
  97. << ", y: " << std::to_string(y) << ".";
  98. }
  99. return static_cast<float>(x) / static_cast<float>(y);
  100. }
  // Exact floored quotient for integral operands. The caller must already have rejected
  // y == 0 and the min / -1 case.
  template <typename T>
  T InnerFloordivImpl(T x, T y, std::true_type /* is_integral */) {
    T quotient = x / y;
    // C++ integer division truncates toward zero; step down once when the exact quotient
    // is negative and has a fractional part.
    if ((x % y != 0) && ((x < 0) != (y < 0))) {
      --quotient;
    }
    return quotient;
  }
  // Floored quotient for floating-point operands, computed in T's own precision.
  template <typename T>
  T InnerFloordivImpl(T x, T y, std::false_type /* is_integral */) {
    return std::floor(x / y);
  }
  // Floor division: the largest whole value not greater than x / y, returned as T.
  // Integral operands use integer arithmetic, because routing them through float gives
  // wrong results once an operand exceeds float's 24-bit mantissa (e.g. 16777217 / 1).
  // Invalid operands are reported by InnerScalarDiv.
  template <typename T>
  T InnerScalarFloordiv(T x, T y) {
    // Called only for its operand validation (zero divisor, min / -1); the quotient is discarded.
    (void)InnerScalarDiv(x, y);
    return InnerFloordivImpl(x, y, std::is_integral<T>());
  }
  109. template <typename T>
  110. T InnerScalarMod(T x, T y) {
  111. if (y == 0) {
  112. MS_LOG(EXCEPTION) << "Could not mod to zero.";
  113. }
  114. if (std::is_integral<T>::value && std::is_signed<T>::value && IsSignedIntOverflow(x, y, OpType::MOD)) {
  115. MS_LOG(EXCEPTION) << "Overflow of the mod of two signed number x: " << std::to_string(x)
  116. << ", y: " << std::to_string(y) << ".";
  117. }
  118. if (std::is_integral<T>::value) {
  119. return static_cast<int>(x) % static_cast<int>(y);
  120. }
  121. int x_int = std::floor(x);
  122. int y_int = std::ceil(y);
  123. int max = x_int / y_int;
  124. float ret = x - y * max;
  125. return ret;
  126. }
  // Raises x to the power y with std::pow and converts the result back to T.
  template <typename T, typename U>
  T InnerScalarPow(T x, U y) {
    const auto powered = std::pow(x, y);
    return powered;
  }
  // Approximate equality: both operands are converted to double and treated as equal
  // when their absolute difference is below DBL_EPSILON.
  template <typename T, typename U>
  bool InnerScalarEq(T x, U y) {
    const double lhs = static_cast<double>(x);
    const double rhs = static_cast<double>(y);
    return std::fabs(lhs - rhs) < DBL_EPSILON;
  }
  // Returns true when x is strictly less than y (built-in comparison of the two operands).
  template <typename T, typename U>
  bool InnerScalarLt(T x, U y) {
    return y > x;
  }
  // Returns true when x is strictly greater than y (built-in comparison of the two operands).
  template <typename T, typename U>
  bool InnerScalarGt(T x, U y) {
    return y < x;
  }
  // Logical negation of InnerScalarEq: true when the operands are not approximately equal.
  template <typename T, typename U>
  bool InnerScalarNe(T x, U y) {
    const bool same = InnerScalarEq(x, y);
    return !same;
  }
  // Returns true when x is less than or equal to y (built-in comparison of the two operands).
  template <typename T, typename U>
  bool InnerScalarLe(T x, U y) {
    return y >= x;
  }
  // Returns true when x is greater than or equal to y (built-in comparison of the two operands).
  template <typename T, typename U>
  bool InnerScalarGe(T x, U y) {
    return y <= x;
  }
  157. #define SCALAR_OP(op_t) \
  158. ValuePtr Scalar##op_t(const ValuePtrList &list) { \
  159. do { \
  160. if (list.size() < 2) { \
  161. MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \
  162. } \
  163. ValuePtr x = list[0]; \
  164. ValuePtr y = list[1]; \
  165. MS_EXCEPTION_IF_NULL(x); \
  166. MS_EXCEPTION_IF_NULL(y); \
  167. if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) { \
  168. double sum = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y)); \
  169. return MakeValue(sum); \
  170. } \
  171. if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) { \
  172. float sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y)); \
  173. return MakeValue(sum); \
  174. } \
  175. if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) { \
  176. int sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y)); \
  177. return MakeValue(sum); \
  178. } \
  179. if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) { \
  180. float sum = InnerScalar##op_t(IntToFloat(GetValue<int>(x)), GetValue<float>(y)); \
  181. return MakeValue(sum); \
  182. } \
  183. if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) { \
  184. float sum = InnerScalar##op_t(GetValue<float>(x), IntToFloat(GetValue<int>(y))); \
  185. return MakeValue(sum); \
  186. } \
  187. MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \
  188. << ", y: " << y->ToString(); \
  189. } while (0); \
  190. }
  191. SCALAR_OP(Add)
  192. SCALAR_OP(Sub)
  193. SCALAR_OP(Mul)
  194. SCALAR_OP(Div)
  195. SCALAR_OP(Mod)
  196. SCALAR_OP(Pow)
  197. SCALAR_OP(Floordiv)
  198. #define LOGIC_OP(op_t) \
  199. ValuePtr Scalar##op_t(const ValuePtrList &list) { \
  200. if (list.size() < 2) { \
  201. MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \
  202. } \
  203. ValuePtr x = list[0]; \
  204. ValuePtr y = list[1]; \
  205. MS_EXCEPTION_IF_NULL(x); \
  206. MS_EXCEPTION_IF_NULL(y); \
  207. if (x->isa<FP64Imm>() && y->isa<FP64Imm>()) { \
  208. bool sum = InnerScalar##op_t(GetValue<double>(x), GetValue<double>(y)); \
  209. return MakeValue(sum); \
  210. } \
  211. if (x->isa<FP32Imm>() && y->isa<FP32Imm>()) { \
  212. bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<float>(y)); \
  213. return MakeValue(sum); \
  214. } \
  215. if (x->isa<FP64Imm>() && y->isa<FP32Imm>()) { \
  216. bool sum = InnerScalar##op_t(GetValue<double>(x), GetValue<float>(y)); \
  217. return MakeValue(sum); \
  218. } \
  219. if (x->isa<FP32Imm>() && y->isa<FP64Imm>()) { \
  220. bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<double>(y)); \
  221. return MakeValue(sum); \
  222. } \
  223. if (x->isa<Int32Imm>() && y->isa<Int32Imm>()) { \
  224. bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<int>(y)); \
  225. return MakeValue(sum); \
  226. } \
  227. if (x->isa<FP32Imm>() && y->isa<Int32Imm>()) { \
  228. bool sum = InnerScalar##op_t(GetValue<float>(x), GetValue<int>(y)); \
  229. return MakeValue(sum); \
  230. } \
  231. if (x->isa<Int32Imm>() && y->isa<FP32Imm>()) { \
  232. bool sum = InnerScalar##op_t(GetValue<int>(x), GetValue<float>(y)); \
  233. return MakeValue(sum); \
  234. } \
  235. if (x->isa<Int64Imm>() && y->isa<Int32Imm>()) { \
  236. bool sum = InnerScalar##op_t(GetValue<int64_t>(x), GetValue<int>(y)); \
  237. return MakeValue(sum); \
  238. } \
  239. MS_LOG(EXCEPTION) << "Unsupported Value for Scalar" << #op_t << ", x: " << x->ToString() \
  240. << ", y: " << y->ToString() << "."; \
  241. }
  242. LOGIC_OP(Eq)
  243. LOGIC_OP(Lt)
  244. LOGIC_OP(Gt)
  245. LOGIC_OP(Ne)
  246. LOGIC_OP(Le)
  247. LOGIC_OP(Ge)
  248. ValuePtr ScalarUAdd(const ValuePtrList &list) {
  249. if (list.size() != 1) {
  250. MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size();
  251. }
  252. ValuePtr x = list[0];
  253. MS_EXCEPTION_IF_NULL(x);
  254. return x;
  255. }
  256. ValuePtr ScalarUSub(const ValuePtrList &list) {
  257. if (list.size() != 1) {
  258. MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size();
  259. }
  260. ValuePtr x = list[0];
  261. MS_EXCEPTION_IF_NULL(x);
  262. if (x->isa<Int32Imm>()) {
  263. int32_t sum = -1 * GetValue<int>(x);
  264. return MakeValue(sum);
  265. }
  266. if (x->isa<FP32Imm>()) {
  267. float sum = -1.0f * GetValue<float>(x);
  268. return MakeValue(sum);
  269. }
  270. MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << ".";
  271. }
  272. ValuePtr ScalarLog(const ValuePtrList &list) {
  273. if (list.empty()) {
  274. MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty.";
  275. }
  276. ValuePtr x = list[0];
  277. MS_EXCEPTION_IF_NULL(x);
  278. if (x->isa<FP64Imm>()) {
  279. double v = log(GetValue<double>(x));
  280. return MakeValue(v);
  281. }
  282. if (x->isa<FP32Imm>()) {
  283. auto v = static_cast<float>(log(GetValue<float>(x)));
  284. return MakeValue(v);
  285. }
  286. MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString();
  287. }
  288. ValuePtr BoolNot(const ValuePtrList &list) {
  289. if (list.empty()) {
  290. MS_LOG(EXCEPTION) << "value list of BoolNot is empty";
  291. }
  292. ValuePtr x = list[0];
  293. MS_EXCEPTION_IF_NULL(x);
  294. bool convert = false;
  295. if (ValueToBool(x, &convert)) {
  296. auto res = !convert;
  297. return MakeValue(res);
  298. }
  299. MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString();
  300. }
  301. ValuePtr BoolAnd(const ValuePtrList &list) {
  302. if (list.size() < 2) {
  303. MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2.";
  304. }
  305. ValuePtr x = list[0];
  306. ValuePtr y = list[1];
  307. MS_EXCEPTION_IF_NULL(x);
  308. MS_EXCEPTION_IF_NULL(y);
  309. bool x_b = false;
  310. bool y_b = false;
  311. if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
  312. auto res = x_b && y_b;
  313. return MakeValue(res);
  314. }
  315. MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << ".";
  316. }
  317. ValuePtr BoolOr(const ValuePtrList &list) {
  318. if (list.size() < 2) {
  319. MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2.";
  320. }
  321. ValuePtr x = list[0];
  322. ValuePtr y = list[1];
  323. MS_EXCEPTION_IF_NULL(x);
  324. MS_EXCEPTION_IF_NULL(y);
  325. bool x_b = false;
  326. bool y_b = false;
  327. if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
  328. auto res = x_b || y_b;
  329. return MakeValue(res);
  330. }
  331. MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << ".";
  332. }
  333. ValuePtr BoolEq(const ValuePtrList &list) {
  334. if (list.size() < 2) {
  335. MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2.";
  336. }
  337. ValuePtr x = list[0];
  338. ValuePtr y = list[1];
  339. MS_EXCEPTION_IF_NULL(x);
  340. MS_EXCEPTION_IF_NULL(y);
  341. bool x_b = false;
  342. bool y_b = false;
  343. if (ValueToBool(x, &x_b) && ValueToBool(y, &y_b)) {
  344. auto res = x_b == y_b;
  345. return MakeValue(res);
  346. }
  347. MS_LOG(EXCEPTION) << "Unsported Value for BoolEq, x: " << x->ToString() << ".";
  348. }
  // Computes the broadcast of two shapes.
  //   shpx, shpy - dimension lists; the shorter one is left-padded with 1s to the longer rank.
  // Per dimension pair (a, b), checked in this order: a == 1 yields b, b == 1 yields a,
  // a == -1 yields b, b == -1 yields a, a == b yields a.
  // Returns the broadcast shape, or an empty vector as soon as a pair matches none of the
  // rules (two empty shapes also yield an empty vector).
  // The padding is a single insert per shape instead of one front insertion per missing
  // dimension, and the former post-padding size check is gone because the ranks are always
  // equal at that point.
  std::vector<int> BroadcastShape_(std::vector<int> shpx, std::vector<int> shpy) {
    if (shpx.size() < shpy.size()) {
      (void)shpx.insert(shpx.begin(), shpy.size() - shpx.size(), 1);
    } else if (shpy.size() < shpx.size()) {
      (void)shpy.insert(shpy.begin(), shpx.size() - shpy.size(), 1);
    }
    std::vector<int> shp;
    shp.reserve(shpx.size());
    for (size_t i = 0; i < shpx.size(); ++i) {
      const int a = shpx[i];
      const int b = shpy[i];
      // The order of these tests is significant, e.g. (-1, 1) must yield -1.
      if (a == 1) {
        shp.push_back(b);
      } else if (b == 1) {
        shp.push_back(a);
      } else if (a == -1) {
        shp.push_back(b);
      } else if (b == -1) {
        shp.push_back(a);
      } else if (a == b) {
        shp.push_back(a);
      } else {
        return std::vector<int>();
      }
    }
    return shp;
  }
  383. } // namespace prim
  384. } // namespace mindspore