You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

context.cc 5.5 kB

Line numbers 1–134 (gutter from the page layout).
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "include/api/context.h"
  17. #include "utils/log_adapter.h"
  // Keys under which settings are stored in a Context's `params` map.
  //
  // Global-context keys: written and read by GlobalContext on the shared global context.
  constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target";
  constexpr auto kGlobalContextDeviceID = "mindspore.ascend.globalcontext.device_id";
  //
  // Model-option keys: written and read by ModelContext on a caller-supplied context.
  //
  // Path to the insert-op (AIPP) config file.
  constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path";
  // Input layout string: "nchw" or "nhwc".
  constexpr auto kModelOptionInputFormat = "mindspore.option.input_format";
  // Input shape string; mandatory while using dynamic batch,
  // e.g. "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1".
  constexpr auto kModelOptionInputShape = "mindspore.option.input_shape";
  // Output data type. Stored as a DataType enum value (not a string); intended values are
  // FP32, UINT8 or FP16, with FP32 documented as the default. Note that GetOutputType returns
  // a value-initialized DataType when the key is unset.
  constexpr auto kModelOptionOutputType = "mindspore.option.output_type";
  // Precision mode string: "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or
  // "allow_mix_precision"; documented default is "force_fp16".
  constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
  // Operator implementation selection mode string (accepted values are not listed in this file).
  constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
  28. namespace mindspore {
  29. template <class T>
  30. static T GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
  31. auto iter = context->params.find(key);
  32. if (iter == context->params.end()) {
  33. return T();
  34. }
  35. const std::any &value = iter->second;
  36. if (value.type() != typeid(T)) {
  37. return T();
  38. }
  39. return std::any_cast<T>(value);
  40. }
  41. std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
  42. static std::shared_ptr<Context> g_context = std::make_shared<Context>();
  43. return g_context;
  44. }
  45. void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
  46. auto global_context = GetGlobalContext();
  47. MS_EXCEPTION_IF_NULL(global_context);
  48. global_context->params[kGlobalContextDeviceTarget] = device_target;
  49. }
  50. std::string GlobalContext::GetGlobalDeviceTarget() {
  51. auto global_context = GetGlobalContext();
  52. MS_EXCEPTION_IF_NULL(global_context);
  53. return GetValue<std::string>(global_context, kGlobalContextDeviceTarget);
  54. }
  55. void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) {
  56. auto global_context = GetGlobalContext();
  57. MS_EXCEPTION_IF_NULL(global_context);
  58. global_context->params[kGlobalContextDeviceID] = device_id;
  59. }
  60. uint32_t GlobalContext::GetGlobalDeviceID() {
  61. auto global_context = GetGlobalContext();
  62. MS_EXCEPTION_IF_NULL(global_context);
  63. return GetValue<uint32_t>(global_context, kGlobalContextDeviceID);
  64. }
  65. void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
  66. MS_EXCEPTION_IF_NULL(context);
  67. context->params[kModelOptionInsertOpCfgPath] = cfg_path;
  68. }
  69. std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) {
  70. MS_EXCEPTION_IF_NULL(context);
  71. return GetValue<std::string>(context, kModelOptionInsertOpCfgPath);
  72. }
  73. void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) {
  74. MS_EXCEPTION_IF_NULL(context);
  75. context->params[kModelOptionInputFormat] = format;
  76. }
  77. std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) {
  78. MS_EXCEPTION_IF_NULL(context);
  79. return GetValue<std::string>(context, kModelOptionInputFormat);
  80. }
  81. void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) {
  82. MS_EXCEPTION_IF_NULL(context);
  83. context->params[kModelOptionInputShape] = shape;
  84. }
  85. std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) {
  86. MS_EXCEPTION_IF_NULL(context);
  87. return GetValue<std::string>(context, kModelOptionInputShape);
  88. }
  89. void ModelContext::SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type) {
  90. MS_EXCEPTION_IF_NULL(context);
  91. context->params[kModelOptionOutputType] = output_type;
  92. }
  93. enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &context) {
  94. MS_EXCEPTION_IF_NULL(context);
  95. return GetValue<enum DataType>(context, kModelOptionOutputType);
  96. }
  97. void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) {
  98. MS_EXCEPTION_IF_NULL(context);
  99. context->params[kModelOptionPrecisionMode] = precision_mode;
  100. }
  101. std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) {
  102. MS_EXCEPTION_IF_NULL(context);
  103. return GetValue<std::string>(context, kModelOptionPrecisionMode);
  104. }
  105. void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
  106. const std::string &op_select_impl_mode) {
  107. MS_EXCEPTION_IF_NULL(context);
  108. context->params[kModelOptionOpSelectImplMode] = op_select_impl_mode;
  109. }
  110. std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
  111. MS_EXCEPTION_IF_NULL(context);
  112. return GetValue<std::string>(context, kModelOptionOpSelectImplMode);
  113. }
  114. } // namespace mindspore

A lightweight and high-performance service module that helps MindSpore developers efficiently deploy online inference services in the production environment.