You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

context.h 14 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
  17. #define MINDSPORE_INCLUDE_API_CONTEXT_H
#include <map>
#include <memory>
#include <string>
#include <type_traits>  // std::is_base_of, used by DeviceInfoContext::Cast()
#include <vector>
#include "include/api/types.h"
#include "include/api/dual_abi_helper.h"
  24. namespace mindspore {
  25. enum DeviceType {
  26. kCPU = 0,
  27. kGPU,
  28. kKirinNPU,
  29. kAscend910,
  30. kAscend310,
  31. // add new type here
  32. kInvalidDeviceType = 100,
  33. };
  34. class Allocator;
  35. class Delegate;
  36. class DeviceInfoContext;
  37. /// \brief Context is used to store environment variables during execution.
  38. class MS_API Context {
  39. public:
  40. Context();
  41. ~Context() = default;
  42. /// \brief Set the number of threads at runtime. This option is only valid for MindSpore Lite.
  43. ///
  44. /// \param[in] thread_num the number of threads at runtime.
  45. void SetThreadNum(int32_t thread_num);
  46. /// \brief Get the current thread number setting.
  47. ///
  48. /// \return The current thread number setting.
  49. int32_t GetThreadNum() const;
  50. /// \brief Set the thread affinity to CPU cores.
  51. ///
  52. /// \param mode: 0: no affinities, 1: big cores first, 2: little cores first
  53. void SetThreadAffinity(int mode);
  54. int GetThreadAffinityMode() const;
  55. void SetThreadAffinity(const std::vector<int> &core_list);
  56. std::vector<int32_t> GetThreadAffinityCoreList() const;
  57. void SetEnableParallel(bool is_parallel);
  58. bool GetEnableParallel() const;
  59. void SetDelegate(const std::shared_ptr<Delegate> &delegate);
  60. std::shared_ptr<Delegate> GetDelegate() const;
  61. /// \brief Get a mutable reference of DeviceInfoContext vector in this context. Only MindSpore Lite supports
  62. /// heterogeneous scenarios with multiple members in the vector.
  63. ///
  64. /// \return Mutable reference of DeviceInfoContext vector in this context.
  65. std::vector<std::shared_ptr<DeviceInfoContext>> &MutableDeviceInfo();
  66. private:
  67. struct Data;
  68. std::shared_ptr<Data> data_;
  69. };
  70. /// \brief DeviceInfoContext defines different device contexts.
  71. class MS_API DeviceInfoContext : public std::enable_shared_from_this<DeviceInfoContext> {
  72. public:
  73. struct Data;
  74. DeviceInfoContext();
  75. virtual ~DeviceInfoContext() = default;
  76. /// \brief Get the type of this DeviceInfoContext.
  77. ///
  78. /// \return Type of this DeviceInfoContext.
  79. virtual enum DeviceType GetDeviceType() const = 0;
  80. /// \brief A similar function to RTTI is provided when the -fno-rtti compilation option is turned on, which converts
  81. /// DeviceInfoContext to a shared pointer of type T, and returns nullptr if the conversion fails.
  82. ///
  83. /// \param T Type
  84. /// \return A pointer of type T after conversion. If the conversion fails, it will be nullptr.
  85. template <class T>
  86. std::shared_ptr<T> Cast() {
  87. static_assert(std::is_base_of<DeviceInfoContext, T>::value, "Wrong cast type.");
  88. if (GetDeviceType() != T().GetDeviceType()) {
  89. return nullptr;
  90. }
  91. return std::static_pointer_cast<T>(shared_from_this());
  92. }
  93. std::string GetProvider() const;
  94. void SetProvider(const std::string &provider);
  95. std::string GetProviderDevice() const;
  96. void SetProviderDevice(const std::string &device);
  97. void SetAllocator(const std::shared_ptr<Allocator> &allocator);
  98. std::shared_ptr<Allocator> GetAllocator() const;
  99. protected:
  100. std::shared_ptr<Data> data_;
  101. };
  102. /// \brief Derived from DeviceInfoContext, The configuration of the model running on the CPU. This option is only valid
  103. /// for MindSpore Lite.
  104. class MS_API CPUDeviceInfo : public DeviceInfoContext {
  105. public:
  106. /// \brief Get the type of this DeviceInfoContext.
  107. ///
  108. /// \return Type of this DeviceInfoContext.
  109. enum DeviceType GetDeviceType() const override { return DeviceType::kCPU; };
  110. /// \brief Set enables to perform the float16 inference
  111. ///
  112. /// \param[in] is_fp16 Enable float16 inference or not.
  113. void SetEnableFP16(bool is_fp16);
  114. /// \brief Get enables to perform the float16 inference
  115. ///
  116. /// \return Whether enable float16 inference.
  117. bool GetEnableFP16() const;
  118. };
  119. /// \brief Derived from DeviceInfoContext, The configuration of the model running on the NPU. This option is only valid
  120. /// for MindSpore Lite.
  121. class MS_API KirinNPUDeviceInfo : public DeviceInfoContext {
  122. public:
  123. /// \brief Get the type of this DeviceInfoContext.
  124. ///
  125. /// \return Type of this DeviceInfoContext.
  126. enum DeviceType GetDeviceType() const override { return DeviceType::kKirinNPU; };
  127. /// \brief Set the NPU frequency.
  128. ///
  129. /// \param[in] frequency Can be set to 1 (low power consumption), 2 (balanced), 3 (high performance), 4 (extreme
  130. /// performance), default as 3.
  131. void SetFrequency(int frequency);
  132. /// \brief Get the NPU frequency.
  133. ///
  134. /// \return NPU frequency
  135. int GetFrequency() const;
  136. };
  137. /// \brief Derived from DeviceInfoContext, The configuration of the model running on the GPU.
  138. class MS_API GPUDeviceInfo : public DeviceInfoContext {
  139. public:
  140. /// \brief Get the type of this DeviceInfoContext.
  141. ///
  142. /// \return Type of this DeviceInfoContext.
  143. enum DeviceType GetDeviceType() const override { return DeviceType::kGPU; };
  144. /// \brief Set device id.
  145. ///
  146. /// \param[in] device_id The device id.
  147. void SetDeviceID(uint32_t device_id);
  148. /// \brief Get the device id.
  149. ///
  150. /// \return The device id.
  151. uint32_t GetDeviceID() const;
  152. void SetGpuTrtInferMode(bool gpu_trt_infer_mode);
  153. bool GetGpuTrtInferMode() const;
  154. inline void SetPrecisionMode(const std::string &precison_mode);
  155. inline std::string GetPrecisionMode() const;
  156. /// \brief Set enables to perform the float16 inference
  157. ///
  158. /// \param[in] is_fp16 Enable float16 inference or not.
  159. void SetEnableFP16(bool is_fp16);
  160. /// \brief Get enables to perform the float16 inference
  161. ///
  162. /// \return Whether enable float16 inference.
  163. bool GetEnableFP16() const;
  164. private:
  165. void SetPrecisionMode(const std::vector<char> &precision_mode);
  166. std::vector<char> GetPrecisionModeChar() const;
  167. };
  168. void GPUDeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  169. SetPrecisionMode(StringToChar(precision_mode));
  170. }
  171. std::string GPUDeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
  172. /// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend910. This option is
  173. /// invalid for MindSpore Lite.
  174. class MS_API Ascend910DeviceInfo : public DeviceInfoContext {
  175. public:
  176. /// \brief Get the type of this DeviceInfoContext.
  177. ///
  178. /// \return Type of this DeviceInfoContext.
  179. enum DeviceType GetDeviceType() const override { return DeviceType::kAscend910; };
  180. /// \brief Set device id.
  181. ///
  182. /// \param[in] device_id The device id.
  183. void SetDeviceID(uint32_t device_id);
  184. /// \brief Get the device id.
  185. ///
  186. /// \return The device id.
  187. uint32_t GetDeviceID() const;
  188. };
  189. /// \brief Derived from DeviceInfoContext, The configuration of the model running on the Ascend310. This option is
  190. /// invalid for MindSpore Lite.
  191. class MS_API Ascend310DeviceInfo : public DeviceInfoContext {
  192. public:
  193. /// \brief Get the type of this DeviceInfoContext.
  194. ///
  195. /// \return Type of this DeviceInfoContext.
  196. enum DeviceType GetDeviceType() const override { return DeviceType::kAscend310; };
  197. /// \brief Set device id.
  198. ///
  199. /// \param[in] device_id The device id.
  200. void SetDeviceID(uint32_t device_id);
  201. /// \brief Get the device id.
  202. ///
  203. /// \return The device id.
  204. uint32_t GetDeviceID() const;
  205. inline void SetDumpConfigPath(const std::string &cfg_path);
  206. inline std::string GetDumpConfigPath() const;
  207. /// \brief Set AIPP configuration file path.
  208. ///
  209. /// \param[in] cfg_path AIPP configuration file path.
  210. inline void SetInsertOpConfigPath(const std::string &cfg_path);
  211. /// \brief Get AIPP configuration file path.
  212. ///
  213. /// \return AIPP configuration file path.
  214. inline std::string GetInsertOpConfigPath() const;
  215. /// \brief Set format of model inputs.
  216. ///
  217. /// \param[in] format Optional "NCHW", "NHWC", etc.
  218. inline void SetInputFormat(const std::string &format);
  219. /// \brief Get format of model inputs.
  220. ///
  221. /// \return The format of model inputs.
  222. inline std::string GetInputFormat() const;
  223. /// \brief Set shape of model inputs.
  224. ///
  225. /// \param[in] shape e.g. "input_op_name1: 1,2,3,4;input_op_name2: 4,3,2,1".
  226. inline void SetInputShape(const std::string &shape);
  227. /// \brief Get shape of model inputs.
  228. ///
  229. /// \return The shape of model inputs.
  230. inline std::string GetInputShape() const;
  231. /// \brief Set shape of model inputs.
  232. ///
  233. /// \param[in] shape e.g. {{1, {1,2,3,4}}, {2, {4,3,2,1}}} means the first input shape 1,2,3,4 and the second input
  234. /// shape 4,3,2,1.
  235. void SetInputShapeMap(const std::map<int, std::vector<int>> &shape);
  236. /// \brief Get shape of model inputs.
  237. ///
  238. /// \return The shape of model inputs.
  239. std::map<int, std::vector<int>> GetInputShapeMap() const;
  240. void SetDynamicBatchSize(const std::vector<size_t> &dynamic_batch_size);
  241. inline std::string GetDynamicBatchSize() const;
  242. /// \brief Set type of model outputs.
  243. ///
  244. /// \param[in] output_type FP32, UINT8 or FP16, default as FP32.
  245. void SetOutputType(enum DataType output_type);
  246. /// \brief Get type of model outputs.
  247. ///
  248. /// \return The set type of model outputs.
  249. enum DataType GetOutputType() const;
  250. /// \brief Set precision mode of model.
  251. ///
  252. /// \param[in] precision_mode Optional "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" and
  253. /// "allow_mix_precision", "force_fp16" is set as default
  254. inline void SetPrecisionMode(const std::string &precision_mode);
  255. /// \brief Get precision mode of model.
  256. ///
  257. /// \return The set type of model outputs
  258. inline std::string GetPrecisionMode() const;
  259. /// \brief Set op select implementation mode.
  260. ///
  261. /// \param[in] op_select_impl_mode Optional "high_performance" and "high_precision", "high_performance" is set as
  262. /// default.
  263. inline void SetOpSelectImplMode(const std::string &op_select_impl_mode);
  264. /// \brief Get op select implementation mode.
  265. ///
  266. /// \return The set op select implementation mode.
  267. inline std::string GetOpSelectImplMode() const;
  268. inline void SetFusionSwitchConfigPath(const std::string &cfg_path);
  269. inline std::string GetFusionSwitchConfigPath() const;
  270. // Optional "l1_optimize", "l2_optimize", "off_optimize" or "l1_and_l2_optimize", default as "l2_optimize"
  271. inline void SetBufferOptimizeMode(const std::string &buffer_optimize_mode);
  272. inline std::string GetBufferOptimizeMode() const;
  273. private:
  274. void SetDumpConfigPath(const std::vector<char> &cfg_path);
  275. std::vector<char> GetDumpConfigPathChar() const;
  276. void SetInsertOpConfigPath(const std::vector<char> &cfg_path);
  277. std::vector<char> GetInsertOpConfigPathChar() const;
  278. void SetInputFormat(const std::vector<char> &format);
  279. std::vector<char> GetInputFormatChar() const;
  280. void SetInputShape(const std::vector<char> &shape);
  281. std::vector<char> GetInputShapeChar() const;
  282. std::vector<char> GetDynamicBatchSizeChar() const;
  283. void SetPrecisionMode(const std::vector<char> &precision_mode);
  284. std::vector<char> GetPrecisionModeChar() const;
  285. void SetOpSelectImplMode(const std::vector<char> &op_select_impl_mode);
  286. std::vector<char> GetOpSelectImplModeChar() const;
  287. void SetFusionSwitchConfigPath(const std::vector<char> &cfg_path);
  288. std::vector<char> GetFusionSwitchConfigPathChar() const;
  289. void SetBufferOptimizeMode(const std::vector<char> &buffer_optimize_mode);
  290. std::vector<char> GetBufferOptimizeModeChar() const;
  291. };
  292. void Ascend310DeviceInfo::SetDumpConfigPath(const std::string &cfg_path) { SetDumpConfigPath(StringToChar(cfg_path)); }
  293. std::string Ascend310DeviceInfo::GetDumpConfigPath() const { return CharToString(GetDumpConfigPathChar()); }
  294. void Ascend310DeviceInfo::SetInsertOpConfigPath(const std::string &cfg_path) {
  295. SetInsertOpConfigPath(StringToChar(cfg_path));
  296. }
  297. std::string Ascend310DeviceInfo::GetInsertOpConfigPath() const { return CharToString(GetInsertOpConfigPathChar()); }
  298. void Ascend310DeviceInfo::SetInputFormat(const std::string &format) { SetInputFormat(StringToChar(format)); }
  299. std::string Ascend310DeviceInfo::GetInputFormat() const { return CharToString(GetInputFormatChar()); }
  300. void Ascend310DeviceInfo::SetInputShape(const std::string &shape) { SetInputShape(StringToChar(shape)); }
  301. std::string Ascend310DeviceInfo::GetInputShape() const { return CharToString(GetInputShapeChar()); }
  302. std::string Ascend310DeviceInfo::GetDynamicBatchSize() const { return CharToString(GetDynamicBatchSizeChar()); }
  303. void Ascend310DeviceInfo::SetPrecisionMode(const std::string &precision_mode) {
  304. SetPrecisionMode(StringToChar(precision_mode));
  305. }
  306. std::string Ascend310DeviceInfo::GetPrecisionMode() const { return CharToString(GetPrecisionModeChar()); }
  307. void Ascend310DeviceInfo::SetOpSelectImplMode(const std::string &op_select_impl_mode) {
  308. SetOpSelectImplMode(StringToChar(op_select_impl_mode));
  309. }
  310. std::string Ascend310DeviceInfo::GetOpSelectImplMode() const { return CharToString(GetOpSelectImplModeChar()); }
  311. void Ascend310DeviceInfo::SetFusionSwitchConfigPath(const std::string &cfg_path) {
  312. SetFusionSwitchConfigPath(StringToChar(cfg_path));
  313. }
  314. std::string Ascend310DeviceInfo::GetFusionSwitchConfigPath() const {
  315. return CharToString(GetFusionSwitchConfigPathChar());
  316. }
  317. void Ascend310DeviceInfo::SetBufferOptimizeMode(const std::string &buffer_optimize_mode) {
  318. SetBufferOptimizeMode(StringToChar(buffer_optimize_mode));
  319. }
  320. std::string Ascend310DeviceInfo::GetBufferOptimizeMode() const { return CharToString(GetBufferOptimizeModeChar()); }
  321. } // namespace mindspore
  322. #endif // MINDSPORE_INCLUDE_API_CONTEXT_H