You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

context.cc 7.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "include/api/context.h"
  17. #include <any>
  18. #include <map>
  19. #include <type_traits>
  20. #include "utils/log_adapter.h"
  // Keys used by GlobalContext to store its settings in the parameter map.
  constexpr const char *kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target";
  constexpr const char *kGlobalContextDeviceID = "mindspore.ascend.globalcontext.device_id";

  // Keys used by ModelContext to store per-model options in the parameter map.
  // aipp config file
  constexpr const char *kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path";
  // nchw or nhwc
  constexpr const char *kModelOptionInputFormat = "mindspore.option.input_format";
  // Mandatory while dynamic batch: e.g. "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1"
  constexpr const char *kModelOptionInputShape = "mindspore.option.input_shape";
  // "FP32", "UINT8" or "FP16", default as "FP32"
  constexpr const char *kModelOptionOutputType = "mindspore.option.output_type";
  // "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
  constexpr const char *kModelOptionPrecisionMode = "mindspore.option.precision_mode";
  constexpr const char *kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
  31. namespace mindspore {
  32. struct Context::Data {
  33. std::map<std::string, std::any> params;
  34. };
  35. Context::Context() : data(std::make_shared<Data>()) {}
  36. template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
  37. static const U &GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
  38. static U empty_result;
  39. if (context == nullptr || context->data == nullptr) {
  40. return empty_result;
  41. }
  42. auto iter = context->data->params.find(key);
  43. if (iter == context->data->params.end()) {
  44. return empty_result;
  45. }
  46. const std::any &value = iter->second;
  47. if (value.type() != typeid(U)) {
  48. return empty_result;
  49. }
  50. return std::any_cast<const U &>(value);
  51. }
  52. std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
  53. static std::shared_ptr<Context> g_context = std::make_shared<Context>();
  54. return g_context;
  55. }
  56. void GlobalContext::SetGlobalDeviceTarget(const std::vector<char> &device_target) {
  57. auto global_context = GetGlobalContext();
  58. MS_EXCEPTION_IF_NULL(global_context);
  59. if (global_context->data == nullptr) {
  60. global_context->data = std::make_shared<Data>();
  61. MS_EXCEPTION_IF_NULL(global_context->data);
  62. }
  63. global_context->data->params[kGlobalContextDeviceTarget] = CharToString(device_target);
  64. }
  65. std::vector<char> GlobalContext::GetGlobalDeviceTargetChar() {
  66. auto global_context = GetGlobalContext();
  67. MS_EXCEPTION_IF_NULL(global_context);
  68. const std::string &ref = GetValue<std::string>(global_context, kGlobalContextDeviceTarget);
  69. return StringToChar(ref);
  70. }
  71. void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) {
  72. auto global_context = GetGlobalContext();
  73. MS_EXCEPTION_IF_NULL(global_context);
  74. if (global_context->data == nullptr) {
  75. global_context->data = std::make_shared<Data>();
  76. MS_EXCEPTION_IF_NULL(global_context->data);
  77. }
  78. global_context->data->params[kGlobalContextDeviceID] = device_id;
  79. }
  80. uint32_t GlobalContext::GetGlobalDeviceID() {
  81. auto global_context = GetGlobalContext();
  82. MS_EXCEPTION_IF_NULL(global_context);
  83. return GetValue<uint32_t>(global_context, kGlobalContextDeviceID);
  84. }
  85. void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path) {
  86. MS_EXCEPTION_IF_NULL(context);
  87. if (context->data == nullptr) {
  88. context->data = std::make_shared<Data>();
  89. MS_EXCEPTION_IF_NULL(context->data);
  90. }
  91. context->data->params[kModelOptionInsertOpCfgPath] = CharToString(cfg_path);
  92. }
  93. std::vector<char> ModelContext::GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context) {
  94. MS_EXCEPTION_IF_NULL(context);
  95. const std::string &ref = GetValue<std::string>(context, kModelOptionInsertOpCfgPath);
  96. return StringToChar(ref);
  97. }
  98. void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format) {
  99. MS_EXCEPTION_IF_NULL(context);
  100. if (context->data == nullptr) {
  101. context->data = std::make_shared<Data>();
  102. MS_EXCEPTION_IF_NULL(context->data);
  103. }
  104. context->data->params[kModelOptionInputFormat] = CharToString(format);
  105. }
  106. std::vector<char> ModelContext::GetInputFormatChar(const std::shared_ptr<Context> &context) {
  107. MS_EXCEPTION_IF_NULL(context);
  108. const std::string &ref = GetValue<std::string>(context, kModelOptionInputFormat);
  109. return StringToChar(ref);
  110. }
  111. void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape) {
  112. MS_EXCEPTION_IF_NULL(context);
  113. if (context->data == nullptr) {
  114. context->data = std::make_shared<Data>();
  115. MS_EXCEPTION_IF_NULL(context->data);
  116. }
  117. context->data->params[kModelOptionInputShape] = CharToString(shape);
  118. }
  119. std::vector<char> ModelContext::GetInputShapeChar(const std::shared_ptr<Context> &context) {
  120. MS_EXCEPTION_IF_NULL(context);
  121. const std::string &ref = GetValue<std::string>(context, kModelOptionInputShape);
  122. return StringToChar(ref);
  123. }
  124. void ModelContext::SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type) {
  125. MS_EXCEPTION_IF_NULL(context);
  126. if (context->data == nullptr) {
  127. context->data = std::make_shared<Data>();
  128. MS_EXCEPTION_IF_NULL(context->data);
  129. }
  130. context->data->params[kModelOptionOutputType] = output_type;
  131. }
  132. enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &context) {
  133. MS_EXCEPTION_IF_NULL(context);
  134. return GetValue<enum DataType>(context, kModelOptionOutputType);
  135. }
  136. void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::vector<char> &precision_mode) {
  137. MS_EXCEPTION_IF_NULL(context);
  138. if (context->data == nullptr) {
  139. context->data = std::make_shared<Data>();
  140. MS_EXCEPTION_IF_NULL(context->data);
  141. }
  142. context->data->params[kModelOptionPrecisionMode] = CharToString(precision_mode);
  143. }
  144. std::vector<char> ModelContext::GetPrecisionModeChar(const std::shared_ptr<Context> &context) {
  145. MS_EXCEPTION_IF_NULL(context);
  146. const std::string &ref = GetValue<std::string>(context, kModelOptionPrecisionMode);
  147. return StringToChar(ref);
  148. }
  149. void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
  150. const std::vector<char> &op_select_impl_mode) {
  151. MS_EXCEPTION_IF_NULL(context);
  152. if (context->data == nullptr) {
  153. context->data = std::make_shared<Data>();
  154. MS_EXCEPTION_IF_NULL(context->data);
  155. }
  156. context->data->params[kModelOptionOpSelectImplMode] = CharToString(op_select_impl_mode);
  157. }
  158. std::vector<char> ModelContext::GetOpSelectImplModeChar(const std::shared_ptr<Context> &context) {
  159. MS_EXCEPTION_IF_NULL(context);
  160. const std::string &ref = GetValue<std::string>(context, kModelOptionOpSelectImplMode);
  161. return StringToChar(ref);
  162. }
  163. } // namespace mindspore

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.