You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

lite_utils.h 17 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_LITE_INCLUDE_LITE_UTILS_H_
  17. #define MINDSPORE_LITE_INCLUDE_LITE_UTILS_H_
  18. #ifndef NOT_USE_STL
  19. #include <vector>
  20. #include <string>
  21. #include <memory>
  22. #include <functional>
  23. #else
  24. #include <stdint.h>
  25. #include <stdlib.h>
  26. #include <string.h>
  27. #include <stddef.h>
  28. #include <stdio.h>
  29. #include <float.h>
  30. #include <new>
  31. #endif // NOT_USE_STL
  32. #ifndef MS_API
  33. #ifdef _WIN32
  34. #define MS_API __declspec(dllexport)
  35. #else
  36. #define MS_API __attribute__((visibility("default")))
  37. #endif
  38. #endif
  39. namespace mindspore {
  // Forward declarations only; none of these types is defined in this header.
  namespace lite {
  struct DeviceContext;
  struct LiteQuantParam;
  }  // namespace lite
  namespace tensor {
  class MSTensor;
  }  // namespace tensor
  namespace schema {
  struct Tensor;
  }  // namespace schema
  50. #ifdef NOT_USE_STL
  /// \brief Fatal-error hook for NOT_USE_STL builds.
  ///
  /// The arguments (a message) are accepted for readability at the call site but are
  /// discarded; the macro expands to exit(1).
  #define MS_C_EXCEPTION(...) exit(1)

  /// \brief Small owning string used when the STL is unavailable (NOT_USE_STL).
  ///
  /// Invariant: `buffer_` points to `size_ + 1` malloc'ed bytes and `buffer_[size_]` is '\0',
  /// so c_str() is always a valid C string. The contents may themselves contain '\0' bytes
  /// (e.g. String(3, '\0')), and size() counts them. Allocation failures and out-of-range
  /// accesses are reported through MS_C_EXCEPTION.
  class String {
   public:
    /// \brief Creates an empty string.
    String() { Init("", 0); }

    /// \brief Creates a string made of `count` copies of `ch`.
    String(size_t count, char ch) {
      buffer_ = Allocate(count);
      memset(buffer_, ch, count);
      buffer_[count] = '\0';
      size_ = count;
    }

    /// \brief Copies at most `count` characters of `s`, stopping early at its terminator.
    ///        A null `s` yields an empty string.
    String(const char *s, size_t count) {
      if (s == nullptr) {
        Init("", 0);
        return;
      }
      // Bounded scan: never reads more than `count` bytes, even if `s` is not NUL-terminated.
      const void *nul = memchr(s, '\0', count);
      Init(s, nul == nullptr ? count : static_cast<size_t>(static_cast<const char *>(nul) - s));
    }

    /// \brief Copies the C string `s`; a null `s` yields an empty string.
    explicit String(const char *s) {
      if (s == nullptr) {
        Init("", 0);
      } else {
        Init(s, strlen(s));
      }
    }

    String(const String &other) { Init(other.buffer_, other.size_); }

    /// \brief Copies the substring [pos, pos + count) of `other`, clamped to its end.
    ///        A `pos` at or past the end yields an empty string.
    String(const String &other, size_t pos, size_t count = npos) {
      if (pos >= other.size_) {
        Init("", 0);
        return;
      }
      // Compare against the remaining length instead of computing pos + count,
      // which can wrap around for large counts.
      const size_t available = other.size_ - pos;
      Init(other.buffer_ + pos, count < available ? count : available);
    }

    ~String() {
      free(buffer_);
      buffer_ = nullptr;
    }

    String &operator=(const String &str) {
      if (this != &str) {
        Replace(str.buffer_, str.size_);
      }
      return *this;
    }

    /// \brief Assigns a C string (null assigns ""). `str` may point into this string.
    String &operator=(const char *str) {
      if (str == nullptr) {
        Replace("", 0);
      } else {
        Replace(str, strlen(str));
      }
      return *this;
    }

    char &at(size_t pos) {
      if (pos >= size_) {
        MS_C_EXCEPTION("pos out of range");
      }
      return buffer_[pos];
    }
    const char &at(size_t pos) const {
      if (pos >= size_) {
        MS_C_EXCEPTION("pos out of range");
      }
      return buffer_[pos];
    }
    // at() already performs the bounds check.
    inline char &operator[](size_t pos) { return this->at(pos); }
    inline const char &operator[](size_t pos) const { return this->at(pos); }

    char *data() noexcept { return buffer_; }
    const char *data() const noexcept { return buffer_; }
    const char *c_str() const noexcept { return buffer_; }

    // capacity
    bool empty() const noexcept { return size_ == 0; }
    size_t size() const noexcept { return size_; }
    size_t length() const noexcept { return size_; }

    // operations
    void clear() noexcept { Replace("", 0); }

    /// \brief Appends `count` copies of `ch` using a single reallocation.
    String &append(size_t count, const char ch) {
      if (count > 0) {
        const String tail(count, ch);
        AppendBytes(tail.buffer_, tail.size_);
      }
      return *this;
    }
    String &append(const String &str) { return *this += str; }
    /// \brief Appends a C string; null is ignored.
    String &append(const char *str) { return *this += str; }

    /// \brief Returns this string followed by `str`. Neither operand is modified.
    String operator+(const String &str) const {
      String result(*this);
      result += str;
      return result;
    }

    String &operator+=(const String &str) {
      AppendBytes(str.buffer_, str.size_);
      return *this;
    }
    /// \brief Appends a C string; null is ignored.
    String &operator+=(const char *str) {
      if (str != nullptr) {
        AppendBytes(str, strlen(str));
      }
      return *this;
    }
    String &operator+=(const char ch) {
      AppendBytes(&ch, 1);
      return *this;
    }

    /// \brief Lexicographic byte comparison that includes embedded '\0' bytes.
    /// \return A negative value, 0, or a positive value.
    int compare(const String &str) const { return CompareBytes(buffer_, size_, str.buffer_, str.size_); }
    /// \brief Compares with a C string; null compares like "".
    int compare(const char *str) const {
      if (str == nullptr) {
        return CompareBytes(buffer_, size_, "", 0);
      }
      return CompareBytes(buffer_, size_, str, strlen(str));
    }
    String substr(size_t pos = 0, size_t count = npos) const { return String(*this, pos, count); }
    static const size_t npos = static_cast<size_t>(-1);

   private:
    // Allocates room for `len` characters plus the terminator.
    static char *Allocate(size_t len) {
      if (len > SIZE_MAX / sizeof(char) - 1) {
        MS_C_EXCEPTION("Invalid string size");
      }
      char *buf = static_cast<char *>(malloc(sizeof(char) * (len + 1)));
      if (buf == nullptr) {
        MS_C_EXCEPTION("malloc data failed");
      }
      return buf;
    }

    // Constructor helper: stores a fresh copy of src[0, len). Does not free anything.
    void Init(const char *src, size_t len) {
      buffer_ = Allocate(len);
      memcpy(buffer_, src, len);
      buffer_[len] = '\0';
      size_ = len;
    }

    // Replaces the contents with src[0, len). The old buffer is freed last so `src` may alias it.
    void Replace(const char *src, size_t len) {
      char *fresh = Allocate(len);
      memcpy(fresh, src, len);
      fresh[len] = '\0';
      free(buffer_);
      buffer_ = fresh;
      size_ = len;
    }

    // Appends src[0, len). memcpy is used because strncat would start at the first embedded
    // '\0' and stop at '\0' inside `src`. `src` may alias the current buffer.
    void AppendBytes(const char *src, size_t len) {
      if (len > SIZE_MAX / sizeof(char) - 1 - size_) {
        MS_C_EXCEPTION("Invalid string size");
      }
      const size_t new_size = size_ + len;
      char *fresh = Allocate(new_size);
      memcpy(fresh, buffer_, size_);
      memcpy(fresh + size_, src, len);
      fresh[new_size] = '\0';
      free(buffer_);
      buffer_ = fresh;
      size_ = new_size;
    }

    // Three-way lexicographic comparison of two byte ranges; a shorter prefix sorts first.
    static int CompareBytes(const char *lhs, size_t lhs_len, const char *rhs, size_t rhs_len) {
      const size_t common = lhs_len < rhs_len ? lhs_len : rhs_len;
      const int result = memcmp(lhs, rhs, common);
      if (result != 0) {
        return result;
      }
      if (lhs_len == rhs_len) {
        return 0;
      }
      return lhs_len < rhs_len ? -1 : 1;
    }

    size_t size_;
    char *buffer_;
  };

  /// \brief Returns `lhs` followed by `rhs` (a null `rhs` appends nothing).
  inline String operator+(const String &lhs, const char *rhs) {
    String str = lhs;
    str += rhs;
    return str;
  }
  /// \brief Returns `lhs` followed by `rhs` (a null `lhs` acts as "").
  inline String operator+(const char *lhs, const String &rhs) {
    String str(lhs);
    str += rhs;
    return str;
  }
  inline bool operator!=(const String &lhs, const String &rhs) { return lhs.compare(rhs) != 0; }
  inline bool operator!=(const String &lhs, const char *rhs) { return lhs.compare(rhs) != 0; }
  inline bool operator!=(const char *lhs, const String &rhs) { return rhs.compare(lhs) != 0; }
  inline bool operator==(const String &lhs, const String &rhs) { return lhs.compare(rhs) == 0; }
  inline bool operator==(const String &lhs, const char *rhs) { return lhs.compare(rhs) == 0; }
  inline bool operator==(const char *lhs, const String &rhs) { return rhs.compare(lhs) == 0; }

  /// \brief Decimal text of `value`.
  inline String to_String(int32_t value) {
    char tmp[sizeof(int32_t) * 4];
    // int32_t may be `long` on some toolchains, so print through long with %ld.
    snprintf(tmp, sizeof(tmp), "%ld", static_cast<long>(value));
    return String(tmp, strlen(tmp));
  }
  /// \brief Text of `value` formatted with "%f" (six digits after the decimal point).
  inline String to_String(float value) {
    char tmp[FLT_MAX_10_EXP + 20];
    snprintf(tmp, sizeof(tmp), "%f", value);
    return String(tmp, strlen(tmp));
  }
  334. #define DEFAULT_CAPACITY 4
  335. #define MIN(x, y) ((x < y) ? (x) : (y))
  336. template <typename T>
  337. class Vector {
  338. public:
  339. Vector() {
  340. size_ = 0;
  341. capacity_ = DEFAULT_CAPACITY;
  342. elem_size_ = sizeof(T);
  343. data_ = nullptr;
  344. }
  345. explicit Vector(size_t size) {
  346. size_ = size;
  347. elem_size_ = sizeof(T);
  348. capacity_ = (size == 0 ? DEFAULT_CAPACITY : size);
  349. data_ = new (std::nothrow) T[capacity_];
  350. if (data_ == nullptr) {
  351. MS_C_EXCEPTION("malloc data failed");
  352. }
  353. }
  354. Vector(size_t size, const T &value) {
  355. size_ = size;
  356. elem_size_ = sizeof(T);
  357. capacity_ = (size == 0 ? DEFAULT_CAPACITY : size);
  358. data_ = new (std::nothrow) T[capacity_];
  359. if (data_ == nullptr) {
  360. MS_C_EXCEPTION("malloc data failed");
  361. }
  362. for (int i = 0; i < static_cast<int>(size_); ++i) {
  363. data_[i] = value;
  364. }
  365. }
  366. Vector(const Vector<T> &vec) {
  367. size_ = vec.size_;
  368. elem_size_ = sizeof(T);
  369. capacity_ = vec.capacity_;
  370. data_ = new (std::nothrow) T[capacity_];
  371. if (data_ == nullptr) {
  372. MS_C_EXCEPTION("malloc data failed");
  373. }
  374. for (int i = 0; i < static_cast<int>(size_); ++i) {
  375. data_[i] = vec.data_[i];
  376. }
  377. }
  378. ~Vector() {
  379. if (data_ != nullptr) {
  380. delete[] data_;
  381. }
  382. }
  383. void clear() {
  384. size_ = 0;
  385. if (data_ != nullptr) {
  386. delete[] data_;
  387. data_ = nullptr;
  388. }
  389. }
  390. void push_back(const T &elem) {
  391. if (data_ == nullptr) {
  392. data_ = new (std::nothrow) T[capacity_];
  393. if (data_ == nullptr) {
  394. MS_C_EXCEPTION("malloc data failed");
  395. }
  396. } else if (size_ == capacity_) {
  397. resize(size_ + 1);
  398. --size_;
  399. }
  400. data_[size_] = elem;
  401. ++size_;
  402. }
  403. void push_back(T &&elem) {
  404. if (data_ == nullptr) {
  405. data_ = new (std::nothrow) T[capacity_];
  406. if (data_ == nullptr) {
  407. MS_C_EXCEPTION("malloc data failed");
  408. }
  409. } else if (size_ == capacity_) {
  410. resize(size_ + 1);
  411. --size_;
  412. }
  413. data_[size_] = elem;
  414. ++size_;
  415. }
  416. void pop_back() {
  417. if (size_ > 0) {
  418. --size_;
  419. } else {
  420. MS_C_EXCEPTION("Index is out of range!");
  421. }
  422. }
  423. void insert(const T &elem, size_t index) {
  424. if (index <= size_) {
  425. ++size_;
  426. if (size_ > capacity_) {
  427. resize(size_);
  428. }
  429. if (index == size_ - 1) {
  430. push_back(elem);
  431. } else {
  432. for (int i = static_cast<int>(size_) - 1; i > static_cast<int>(index); --i) {
  433. data_[i + 1] = data_[i];
  434. }
  435. data_[index] = elem;
  436. }
  437. } else {
  438. MS_C_EXCEPTION("Input index is out of range!");
  439. }
  440. }
  441. T *begin() { return data_; }
  442. const T *begin() const { return data_; }
  443. T *end() { return data_ + size_; }
  444. const T *end() const { return data_ + size_; }
  445. T &front() {
  446. if (size_ > 0) {
  447. return data_[0];
  448. }
  449. MS_C_EXCEPTION("Index is out of range!");
  450. }
  451. const T &front() const {
  452. if (size_ > 0) {
  453. return data_[0];
  454. }
  455. MS_C_EXCEPTION("Index is out of range!");
  456. }
  457. T &back() {
  458. if (size_ > 0) {
  459. return data_[size_ - 1];
  460. }
  461. MS_C_EXCEPTION("Index is out of range!");
  462. }
  463. const T &back() const {
  464. if (size_ > 0) {
  465. return data_[size_ - 1];
  466. }
  467. MS_C_EXCEPTION("Index is out of range!");
  468. }
  469. T &at(size_t index) {
  470. if (index < size_) {
  471. return data_[index];
  472. }
  473. MS_C_EXCEPTION("Input index is out of range!");
  474. }
  475. const T &at(size_t index) const {
  476. if (index < size_) {
  477. return data_[index];
  478. }
  479. MS_C_EXCEPTION("Input index is out of range!");
  480. }
  481. T &operator[](size_t index) {
  482. if (index < size_) {
  483. return data_[index];
  484. }
  485. MS_C_EXCEPTION("Input index is out of range!");
  486. }
  487. const T &operator[](size_t index) const {
  488. if (index < size_) {
  489. return data_[index];
  490. }
  491. MS_C_EXCEPTION("Input index is out of range!");
  492. }
  493. T *data() { return data_; }
  494. const T *data() const { return data_; }
  495. size_t size() const { return size_; }
  496. size_t capacity() const { return capacity_; }
  497. bool empty() const { return size_ == 0; }
  498. void erase(size_t index) {
  499. if (index == size_ - 1) {
  500. --size_;
  501. } else if (index < size_) {
  502. for (int i = index; i < static_cast<int>(size_); ++i) {
  503. data_[i] = data_[i + 1];
  504. }
  505. --size_;
  506. } else {
  507. MS_C_EXCEPTION("Input index is out of range!");
  508. }
  509. }
  510. void resize(size_t size) {
  511. while (size > capacity_) {
  512. capacity_ *= 2;
  513. }
  514. T *tmp = data_;
  515. data_ = new (std::nothrow) T[capacity_];
  516. if (data_ == nullptr) {
  517. MS_C_EXCEPTION("malloc data failed");
  518. }
  519. for (int i = 0; i < MIN(static_cast<int>(size), static_cast<int>(size_)); ++i) {
  520. data_[i] = tmp[i];
  521. }
  522. size_ = size;
  523. delete[] tmp;
  524. }
  525. void reserve(size_t capacity) {
  526. if (capacity > capacity_) {
  527. capacity_ = capacity;
  528. }
  529. }
  530. Vector<T> &operator=(const Vector<T> &vec) {
  531. if (this == &vec) {
  532. return *this;
  533. }
  534. size_ = vec.size_;
  535. elem_size_ = sizeof(T);
  536. capacity_ = vec.capacity_;
  537. if (data_ != nullptr) {
  538. delete[] data_;
  539. data_ = nullptr;
  540. }
  541. data_ = new (std::nothrow) T[capacity_];
  542. if (data_ == nullptr) {
  543. MS_C_EXCEPTION("malloc data failed");
  544. }
  545. for (int i = 0; i < static_cast<int>(size_); ++i) {
  546. data_[i] = vec.data_[i];
  547. }
  548. return *this;
  549. }
  550. private:
  551. size_t size_;
  552. size_t elem_size_;
  553. size_t capacity_;
  554. T *data_;
  555. };
  556. using TensorPtrVector = Vector<mindspore::schema::Tensor *>;
  557. using Uint32Vector = Vector<uint32_t>;
  558. class Allocator;
  559. using AllocatorPtr = void *;
  560. class Delegate;
  561. using DelegatePtr = void *;
  562. using DeviceContextVector = Vector<lite::DeviceContext>;
  563. using KernelCallBack = void (*)(void *, void *);
  564. #else
  565. /// \brief Allocator defined a memory pool for malloc memory and free memory dynamically.
  566. ///
  567. /// \note List public class and interface for reference.
  568. class Allocator;
  569. using AllocatorPtr = std::shared_ptr<Allocator>;
  570. class Delegate;
  571. using DelegatePtr = std::shared_ptr<Delegate>;
  572. using TensorPtrVector = std::vector<mindspore::schema::Tensor *>;
  573. using Uint32Vector = std::vector<uint32_t>;
  574. template <typename T>
  575. using Vector = std::vector<T>;
  /// \brief Converts `value` to text by forwarding to std::to_string (arithmetic types).
  template <typename T>
  inline std::string to_string(T value) {
    std::string text = std::to_string(value);
    return text;
  }
  // The session and tensor APIs both spell their string type as a plain std::string.
  namespace session {
  using String = std::string;
  }  // namespace session
  namespace tensor {
  using String = std::string;
  }  // namespace tensor
  586. /// \brief CallBackParam defined input arguments for callBack function.
  587. struct CallBackParam {
  588. session::String node_name; /**< node name argument */
  589. session::String node_type; /**< node type argument */
  590. };
  591. struct GPUCallBackParam : CallBackParam {
  592. double execute_time{-1.f};
  593. };
  594. /// \brief KernelCallBack defined the function pointer for callBack.
  595. using KernelCallBack = std::function<bool(Vector<tensor::MSTensor *> inputs, Vector<tensor::MSTensor *> outputs,
  596. const CallBackParam &opInfo)>;
  597. namespace lite {
  598. using String = std::string;
  599. using DeviceContextVector = std::vector<DeviceContext>;
  600. /// \brief Set data of MSTensor from string vector.
  601. ///
  602. /// \param[in] input string vector.
  603. /// \param[out] MSTensor.
  604. ///
  605. /// \return STATUS as an error code of this interface, STATUS is defined in errorcode.h.
  606. int MS_API StringsToMSTensor(const Vector<String> &inputs, tensor::MSTensor *tensor);
  607. /// \brief Get string vector from MSTensor.
  608. /// \param[in] MSTensor.
  609. /// \return string vector.
  610. Vector<String> MS_API MSTensorToStrings(const tensor::MSTensor *tensor);
  611. } // namespace lite
  612. #endif // NOT_USE_STL
  613. } // namespace mindspore
  614. #endif // MINDSPORE_LITE_INCLUDE_LITE_UTILS_H_