You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

context.h 18 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
  17. #define MINDSPORE_INCLUDE_API_CONTEXT_H
  18. #include <map>
  19. #include <memory>
  20. #include <string>
  21. #include <type_traits>
  22. #include <vector>
  23. #include "include/api/types.h"
  24. #include "include/api/dual_abi_helper.h"
  24. namespace mindspore {
  25. enum DeviceType {
  26. kCPU = 0,
  27. kGPU,
  28. kKirinNPU,
  29. kAscend,
  30. kAscend910,
  31. kAscend310,
  32. // add new type here
  33. kInvalidDeviceType = 100,
  34. };
  35. class Allocator;
  36. class Delegate;
  37. class DeviceInfoContext;
  38. /// \brief Context is used to store environment variables during execution.
  39. class MS_API Context {
  40. public:
  41. struct Data;
  42. Context();
  43. ~Context() = default;
  44. /// \brief Set the number of threads at runtime. Only valid for Lite.
  45. ///
  46. /// \param[in] thread_num the number of threads at runtime.
  47. void SetThreadNum(int32_t thread_num);
  48. /// \brief Get the current thread number setting. Only valid for Lite.
  49. ///
  50. /// \return The current thread number setting.
  51. int32_t GetThreadNum() const;
  52. /// \brief Set the thread affinity to CPU cores. Only valid for Lite.
  53. ///
  54. /// \param[in] mode: 0: no affinities, 1: big cores first, 2: little cores first
  55. void SetThreadAffinity(int mode);
  56. /// \brief Get the thread affinity of CPU cores. Only valid for Lite.
  57. ///
  58. /// \return Thread affinity to CPU cores. 0: no affinities, 1: big cores first, 2: little cores first
  59. int GetThreadAffinityMode() const;
  60. /// \brief Set the thread lists to CPU cores. Only valid for Lite.
  61. ///
  62. /// \note If core_list and mode are set by SetThreadAffinity at the same time, the core_list is effective, but the
  63. /// mode is not effective.
  64. ///
  65. /// \param[in] core_list: a vector of thread core lists.
  66. void SetThreadAffinity(const std::vector<int> &core_list);
  67. /// \brief Get the thread lists of CPU cores. Only valid for Lite.
  68. ///
  69. /// \return core_list: a vector of thread core lists.
  70. std::vector<int32_t> GetThreadAffinityCoreList() const;
  71. /// \brief Set the status whether to perform model inference or training in parallel. Only valid for Lite.
  72. ///
  73. /// \param[in] is_parallel: true, parallel; false, not in parallel.
  74. void SetEnableParallel(bool is_parallel);
  75. /// \brief Get the status whether to perform model inference or training in parallel. Only valid for Lite.
  76. ///
  77. /// \return Bool value that indicates whether in parallel.
  78. bool GetEnableParallel() const;
  79. /// \brief Set Delegate to access third-party AI framework. Only valid for Lite.
  80. ///
  81. /// \param[in] Pointer to the custom delegate.
  82. void SetDelegate(const std::shared_ptr<Delegate> &delegate);
  83. /// \brief Get the delegate of the third-party AI framework. Only valid for Lite.
  84. ///
  85. /// \return Pointer to the custom delegate.
  86. std::shared_ptr<Delegate> GetDelegate() const;
  87. /// \brief Set quant model to run as float model in multi device.
  88. ///
  89. /// \param[in] float_mode: true, run as float model; false, not run as float model.
  90. void SetMultiModalHW(bool float_mode);
  91. /// \brief Get the mode of the model run.
  92. ///
  93. /// \return Bool value that indicates whether run as float model
  94. bool GetMultiModalHW() const;
  95. /// \brief Get a mutable reference of DeviceInfoContext vector in this context. Only MindSpore Lite supports
  96. /// heterogeneous scenarios with multiple members in the vector.
  97. ///
  98. /// \return Mutable reference of DeviceInfoContext vector in this context.
  99. std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
  100. private:
  101. std::shared_ptr<Data> data_;
  102. };
  103. /// \brief DeviceInfoContext defines different device contexts.
  104. class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
  105. public:
  106. struct Data;
  107. DeviceInfoContext();
  108. virtual ~DeviceInfoContext() = default;
  109. /// \brief Get the type of this DeviceInfoContext.
  110. ///
  111. /// \return Type of this DeviceInfoContext.
  112. virtual enum DeviceType GetDeviceType() const = 0;
  113. /// \brief A similar function to RTTI is provided when the -fno-rtti compilation option is turned on, which converts
  114. /// DeviceInfoContext to a shared pointer of type T, and returns nullptr if the conversion fails.
  115. ///
  116. /// \param T Type
  117. /// \return A pointer of type T after conversion. If the conversion fails, it will be nullptr.
  118. template <class T>
  119. std::shared_ptr<T> Cast() {
  120. static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
  121. if (GetDeviceType() != T().GetDeviceType()) {
  122. return nullptr;
  123. }
  124. return std::static_pointer_cast<T>(shared_from_this());
  125. }
  126. /// \brief obtain provider's name
  127. ///
  128. /// \return provider's name.
  129. inline std::string GetProvider() const;
  130. /// \brief set provider's name.
  131. ///
  132. /// \param[in] provider define the provider's name.
  133. inline void SetProvider(const std::string &provider);
  134. /// \brief obtain provider's device type.
  135. ///
  136. /// \return provider's device type.
  137. inline std::string GetProviderDevice() const;
  138. /// \brief set provider's device type.
  139. ///
  140. /// \param[in] device define the provider's device type.EG: CPU.
  141. inline void SetProviderDevice(const std::string &device);
  142. /// \brief set memory allocator.
  143. ///
  144. /// \param[in] allocator define the memory allocator which can be defined by user.
  145. void SetAllocator(const std::shared_ptr<Allocator> &allocator);
  146. /// \brief obtain memory allocator.
  147. ///
  148. /// \return memory allocator.
  149. std::shared_ptr<Allocator> GetAllocator() const;
  150. protected:
  151. std::vector<char> GetProviderChar() const;
  152. void SetProvider(const std::vector<char> &provider);
  153. std::vector<char> GetProviderDeviceChar() const;
  154. void SetProviderDevice(const std::vector<char> &device);
  155. std::shared_ptr<Data> data_;
  156. };
  157. std::string DeviceInfoContext::GetProvider() const { return CharToString(GetProviderChar()); }
  158. void DeviceInfoContext::SetProvider(const std::string &provider) { SetProvider(StringToChar(provider)); }
  159. std::string DeviceInfoContext::GetProviderDevice() const { return CharToString(GetProviderDeviceChar()); }
  160. void DeviceInfoContext::SetProviderDevice(const std::string &device) { SetProviderDevice(StringToChar(device)); }
  161. /// \brief Derived from DeviceInfoContext, The configuration of the model running on the CPU. This option is only valid
  162. /// for MindSpore Lite.
  163. class MS_API CPUDeviceInfo : public DeviceInfoContext {
  164. public:
  165. /// \brief Get the type of this DeviceInfoContext.
  166. ///
  167. /// \return Type of this DeviceInfoContext.
  168. enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
  169. /// \brief Set enables to perform the float16 inference
  170. ///
  171. /// \param[in] is_fp16 Enable float16 inference or not.
  172. void SetEnableFP16(bool is_fp16);
  173. /// \brief Get enables to perform the float16 inference
  174. ///
  175. /// \return Whether enable float16 inference.
  176. bool GetEnableFP16() const;
  177. };
  178. /// \brief Derived from DeviceInfoContext, The configuration of the model running on the NPU. This option is only valid
  179. /// for MindSpore Lite.
  180. class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
  181. public:
  182. /// \brief Get the type of this DeviceInfoContext.
  183. ///
  184. /// \return Type of this DeviceInfoContext.
  185. enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };
  186. /// \brief Set the NPU frequency.
  187. ///
  188. /// \param[in] frequency Can be set to 1 (low power consumption), 2 (balanced), 3 (high performance), 4 (extreme
  189. /// performance), default as 3.
  190. void SetFrequency(int frequency);
  191. /// \brief Get the NPU frequency.
  192. ///
  193. /// \return NPU frequency
  194. int GetFrequency() const;
  195. };
  196. /// \brief Derived from DeviceInfoContext, The configuration of the model running on the GPU.
  197. class MS_API GPUDeviceInfo : public DeviceInfoContext {
  198. public:
  199. /// \brief Get the type of this DeviceInfoContext.
  200. ///
  201. /// \return Type of this DeviceInfoContext.
  202. enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; };
  203. /// \brief Set device id.
  204. ///
  205. /// \param[in] device_id The device id.
  206. void SetDeviceID(uint32_t device_id);
  207. /// \brief Get the device id.
  208. ///
  209. /// \return The device id.
  210. uint32_t GetDeviceID() const;
  211. /// \brief Get the distribution rank id.
  212. ///
  213. /// \return The device id.
  214. int GetRankID() const;
  215. /// \brief Get the distribution group size.
  216. ///
  217. /// \return The device id.
  218. int GetGroupSize() const;
  219. /// \brief Set the precision mode.
  220. ///
  221. /// \param[in] precision_mode Optional "origin", "fp16". "origin" is set as default.
  222. inline void SetPrecisionMode(const std::string &precision_mode);
  223. /// \brief Get the precision mode.
  224. ///
  225. /// \return The precision mode.
  226. inline std::string GetPrecisionMode() const;
  227. /// \brief Set enables to perform the float16 inference
  228. ///
  229. /// \param[in] is_fp16 Enable float16 inference or not.
  230. void SetEnableFP16(bool is_fp16);
  231. /// \brief Get enables to perform the float16 inference
  232. ///
  233. /// \return Whether enable float16 inference.
  234. bool GetEnableFP16() const;
  235. /// \brief Set enables to sharing mem with OpenGL
  236. ///
  237. /// \param[in] is_enable_sharing_mem_with_gl Enable sharing OpenCL Memory with OpenGL or not.
  238. void SetEnableGLTexture(bool is_enable_gl_texture);
  239. /// \brief Get enables to sharing mem with OpenGL
  240. ///
  241. /// \return Whether enable sharing mem with OpenGL.
  242. bool GetEnableGLTexture() const;
  243. /// \brief Set current OpenGL context
  244. ///
  245. /// \param[in] gl_context Current OpenGL context.
  246. void SetGLContext(void *gl_context);
  247. /// \brief Get current OpenGL context
  248. ///
  249. /// \return the OpenCL context by OpenGL used.
  250. void *GetGLContext() const;
  251. /// \brief Set current OpenGL display
  252. ///
  253. /// \param[in] gl_display Current OpenGL display.
  254. void SetGLDisplay(void *gl_display);
  255. /// \brief Get current OpenGL display
  256. ///
  257. /// \return the OpenCL display by OpenGL used.
  258. void *GetGLDisplay() const;
  259. private:
  260. void SetPrecisionMode(const std::vector<char> &precision_mode);
  261. std::vector<char> GetPrecisionModeChar() const;
  262. };
  263. void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  264. SetPrecisionMode(StringToChar(precision_mode));
  265. }
  266. std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
  267. /// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend. This option is
  268. /// invalid for MindSpore Lite.
  269. class MS_API AscendDeviceInfo : public DeviceInfoContext {
  270. public:
  271. /// \brief Get the type of this DeviceInfoContext.
  272. ///
  273. /// \return Type of this DeviceInfoContext.
  274. enum DeviceType GetDeviceType() const override { return DeviceType::kAscend; };
  275. /// \brief Set device id.
  276. ///
  277. /// \param[in] device_id The device id.
  278. void SetDeviceID(uint32_t device_id);
  279. /// \brief Get the device id.
  280. ///
  281. /// \return The device id.
  282. uint32_t GetDeviceID() const;
  283. /// \brief Set AIPP configuration file path.
  284. ///
  285. /// \param[in] cfg_path AIPP configuration file path.
  286. inline void SetInsertOpConfigPath(const std::string &cfg_path);
  287. /// \brief Get AIPP configuration file path.
  288. ///
  289. /// \return AIPP configuration file path.
  290. inline std::string GetInsertOpConfigPath() const;
  291. /// \brief Set format of model inputs.
  292. ///
  293. /// \param[in] format Optional "NCHW", "NHWC", etc.
  294. inline void SetInputFormat(const std::string &format);
  295. /// \brief Get format of model inputs.
  296. ///
  297. /// \return The format of model inputs.
  298. inline std::string GetInputFormat() const;
  299. /// \brief Set shape of model inputs.
  300. ///
  301. /// \param[in] shape e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1".
  302. inline void SetInputShape(const std::string &shape);
  303. /// \brief Get shape of model inputs.
  304. ///
  305. /// \return The shape of model inputs.
  306. inline std::string GetInputShape() const;
  307. /// \brief Set shape of model inputs.
  308. ///
  309. /// \param[in] shape e.g. {{1, {1,2,3,4}}, {2, {4,3,2,1}}} means the first input shape 1,2,3,4 and the second input
  310. /// shape 4,3,2,1.
  311. void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
  312. /// \brief Get shape of model inputs.
  313. ///
  314. /// \return The shape of model inputs.
  315. std::map<int, std::vector<int>> GetInputShapeMap() const;
  316. void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
  317. inline std::string GetDynamicBatchSize() const;
  318. /// \brief Set the dynamic image size of model inputs.
  319. ///
  320. /// \param[in] image size hw e.g. "66,88;32,64" means h1:66,w1:88; h2:32,w2:64.
  321. inline void SetDynamicImageSize(const std::string &dynamic_image_size);
  322. /// \brief Get dynamic image size of model inputs.
  323. ///
  324. /// \return The image size of model inputs.
  325. inline std::string GetDynamicImageSize() const;
  326. /// \brief Set type of model outputs.
  327. ///
  328. /// \param[in] output_type FP32, UINT8 or FP16, default as FP32.
  329. void SetOutputType(enum DataType output_type);
  330. /// \brief Get type of model outputs.
  331. ///
  332. /// \return The set type of model outputs.
  333. enum DataType GetOutputType() const;
  334. /// \brief Set precision mode of model.
  335. ///
  336. /// \param[in] precision_mode Optional "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" and
  337. /// "allow_mix_precision", "force_fp16" is set as default
  338. inline void SetPrecisionMode(const std::string &precision_mode);
  339. /// \brief Get precision mode of model.
  340. ///
  341. /// \return The set type of model outputs
  342. inline std::string GetPrecisionMode() const;
  343. /// \brief Set op select implementation mode.
  344. ///
  345. /// \param[in] op_select_impl_mode Optional "high_performance" and "high_precision", "high_performance" is set as
  346. /// default.
  347. inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
  348. /// \brief Get op select implementation mode.
  349. ///
  350. /// \return The set op select implementation mode.
  351. inline std::string GetOpSelectImplMode() const;
  352. inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
  353. inline std::string GetFusionSwitchConfigPath() const;
  354. // Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize"
  355. inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);
  356. inline std::string GetBufferOptimizeMode() const;
  357. private:
  358. void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
  359. std::vector<char> GetInsertOpConfigPathChar() const;
  360. void SetInputFormat(const std::vector<char> &format);
  361. std::vector<char> GetInputFormatChar() const;
  362. void SetInputShape(const std::vector<char> &shape);
  363. std::vector<char> GetInputShapeChar() const;
  364. std::vector<char> GetDynamicBatchSizeChar() const;
  365. void SetDynamicImageSize(const std::vector<char> &dynamic_image_size);
  366. std::vector<char> GetDynamicImageSizeChar() const;
  367. void SetPrecisionMode(const std::vector<char> &precision_mode);
  368. std::vector<char> GetPrecisionModeChar() const;
  369. void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
  370. std::vector<char> GetOpSelectImplModeChar() const;
  371. void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
  372. std::vector<char> GetFusionSwitchConfigPathChar() const;
  373. void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
  374. std::vector<char> GetBufferOptimizeModeChar() const;
  375. };
  376. using Ascend310DeviceInfo = AscendDeviceInfo;
  377. using Ascend910DeviceInfo = AscendDeviceInfo;
  378. using Ascend710DeviceInfo = AscendDeviceInfo;
  379. void AscendDeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
  380. SetInsertOpConfigPath(StringToChar(cfg_path));
  381. }
  382. std::string AscendDeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
  383. void AscendDeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
  384. std::string AscendDeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
  385. void AscendDeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
  386. std::string AscendDeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
  387. std::string AscendDeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
  388. void AscendDeviceInfo::SetDynamicImageSize(const std::string &dynamic_image_size) {
  389. SetDynamicImageSize(StringToChar(dynamic_image_size));
  390. }
  391. std::string AscendDeviceInfo::GetDynamicImageSize() const { return CharToString(GetDynamicImageSizeChar()); }
  392. void AscendDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  393. SetPrecisionMode(StringToChar(precision_mode));
  394. }
  395. std::string AscendDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
  396. void AscendDeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
  397. SetOpSelectImplMode(StringToChar(op_select_impl_mode));
  398. }
  399. std::string AscendDeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
  400. void AscendDeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
  401. SetFusionSwitchConfigPath(StringToChar(cfg_path));
  402. }
  403. std::string AscendDeviceInfo::GetFusionSwitchConfigPath() const {
  404. return CharToString(GetFusionSwitchConfigPathChar());
  405. }
  406. void AscendDeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
  407. SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
  408. }
  409. std::string AscendDeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
  410. } // namespace mindspore
  411. #endif // MINDSPORE_INCLUDE_API_CONTEXT_H