/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "include/api/context.h"
#include <any>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <typeinfo>
#include <vector>
#include "utils/log_adapter.h"

// Keys under which the options below are stored in Context::Data::params.
// NOTE(review): the "default as ..." notes describe intended defaults; this file does not apply them.
// An option that was never set reads back as an empty string / value-initialized DataType.
constexpr auto kGlobalContextDeviceTarget = "mindspore.ascend.globalcontext.device_target";
constexpr auto kGlobalContextDeviceID = "mindspore.ascend.globalcontext.device_id";
constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path";  // aipp config file
constexpr auto kModelOptionInputFormat = "mindspore.option.input_format";                    // nchw or nhwc
// Mandatory while dynamic batch: e.g. "input_op_name1: n1,c2,h3,w4;input_op_name2: n4,c3,h2,w1"
constexpr auto kModelOptionInputShape = "mindspore.option.input_shape";
// Stored as enum DataType; intended values "FP32", "UINT8" or "FP16", default as "FP32".
constexpr auto kModelOptionOutputType = "mindspore.option.output_type";
// "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";

namespace mindspore {
// Backing store of a Context: option key -> type-erased option value.
struct Context::Data {
  std::map<std::string, std::any> params;
};

Context::Context() : data(std::make_shared<Data>()) {}

// Reads option `key` from `context` as a value of type U (T with cv/ref qualifiers removed).
// Returns a reference to a static, value-initialized U (empty string, 0, ...) when `context` or its
// data is null, when `key` is absent, or when the stored value is not exactly of type U.
template <class T, typename U = std::remove_cv_t<std::remove_reference_t<T>>>
static const U &GetValue(const std::shared_ptr<Context> &context, const std::string &key) {
  static const U empty_result{};
  if (context == nullptr || context->data == nullptr) {
    return empty_result;
  }
  const auto &params = context->data->params;
  auto iter = params.find(key);
  if (iter == params.end()) {
    return empty_result;
  }
  const std::any &value = iter->second;
  // std::any_cast<const U &> would throw on a type mismatch; treat a mistyped option as unset instead.
  if (value.type() != typeid(U)) {
    return empty_result;
  }
  return std::any_cast<const U &>(value);
}

// Returns the option map of `context`, allocating `context->data` first if it is null.
// Precondition: `context` is non-null (every caller checks it with MS_EXCEPTION_IF_NULL first).
static std::map<std::string, std::any> &MutableParams(const std::shared_ptr<Context> &context) {
  if (context->data == nullptr) {
    // std::make_shared reports allocation failure by throwing std::bad_alloc, never by returning null.
    context->data = std::make_shared<Context::Data>();
  }
  return context->data->params;
}

// Returns the process-wide context shared by all GlobalContext accessors.
// The function-local static is created on first use (initialization is thread-safe since C++11).
// NOTE(review): the Set*/Get* functions below mutate/read its map without any locking.
std::shared_ptr<Context> GlobalContext::GetGlobalContext() {
  static std::shared_ptr<Context> g_context = std::make_shared<Context>();
  return g_context;
}

// Stores the global device target string.
void GlobalContext::SetGlobalDeviceTarget(const std::vector<char> &device_target) {
  auto global_context = GetGlobalContext();
  MS_EXCEPTION_IF_NULL(global_context);
  MutableParams(global_context)[kGlobalContextDeviceTarget] = CharToString(device_target);
}

// Returns the global device target, or an empty sequence if it was never set.
std::vector<char> GlobalContext::GetGlobalDeviceTargetChar() {
  auto global_context = GetGlobalContext();
  MS_EXCEPTION_IF_NULL(global_context);
  return StringToChar(GetValue<std::string>(global_context, kGlobalContextDeviceTarget));
}

// Stores the global device id.
void GlobalContext::SetGlobalDeviceID(const uint32_t &device_id) {
  auto global_context = GetGlobalContext();
  MS_EXCEPTION_IF_NULL(global_context);
  MutableParams(global_context)[kGlobalContextDeviceID] = device_id;
}

// Returns the global device id, or 0 if it was never set.
uint32_t GlobalContext::GetGlobalDeviceID() {
  auto global_context = GetGlobalContext();
  MS_EXCEPTION_IF_NULL(global_context);
  return GetValue<uint32_t>(global_context, kGlobalContextDeviceID);
}

// Stores the AIPP insert-op config file path in `context`.
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context,
                                         const std::vector<char> &cfg_path) {
  MS_EXCEPTION_IF_NULL(context);
  MutableParams(context)[kModelOptionInsertOpCfgPath] = CharToString(cfg_path);
}

// Returns the AIPP insert-op config file path, or an empty sequence if unset.
std::vector<char> ModelContext::GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context) {
  MS_EXCEPTION_IF_NULL(context);
  return StringToChar(GetValue<std::string>(context, kModelOptionInsertOpCfgPath));
}

// Stores the input format string (e.g. nchw or nhwc) in `context`.
void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format) {
  MS_EXCEPTION_IF_NULL(context);
  MutableParams(context)[kModelOptionInputFormat] = CharToString(format);
}

// Returns the input format string, or an empty sequence if unset.
std::vector<char> ModelContext::GetInputFormatChar(const std::shared_ptr<Context> &context) {
  MS_EXCEPTION_IF_NULL(context);
  return StringToChar(GetValue<std::string>(context, kModelOptionInputFormat));
}

// Stores the input shape description string in `context`.
void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape) {
  MS_EXCEPTION_IF_NULL(context);
  MutableParams(context)[kModelOptionInputShape] = CharToString(shape);
}

// Returns the input shape description string, or an empty sequence if unset.
std::vector<char> ModelContext::GetInputShapeChar(const std::shared_ptr<Context> &context) {
  MS_EXCEPTION_IF_NULL(context);
  return StringToChar(GetValue<std::string>(context, kModelOptionInputShape));
}

// Stores the requested output data type in `context`.
void ModelContext::SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type) {
  MS_EXCEPTION_IF_NULL(context);
  MutableParams(context)[kModelOptionOutputType] = output_type;
}

// Returns the requested output data type, or a value-initialized DataType if unset.
enum DataType ModelContext::GetOutputType(const std::shared_ptr<Context> &context) {
  MS_EXCEPTION_IF_NULL(context);
  return GetValue<enum DataType>(context, kModelOptionOutputType);
}

// Stores the precision mode string in `context`.
void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context,
                                    const std::vector<char> &precision_mode) {
  MS_EXCEPTION_IF_NULL(context);
  MutableParams(context)[kModelOptionPrecisionMode] = CharToString(precision_mode);
}

// Returns the precision mode string, or an empty sequence if unset.
std::vector<char> ModelContext::GetPrecisionModeChar(const std::shared_ptr<Context> &context) {
  MS_EXCEPTION_IF_NULL(context);
  return StringToChar(GetValue<std::string>(context, kModelOptionPrecisionMode));
}

// Stores the op-select implementation mode string in `context`.
void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
                                       const std::vector<char> &op_select_impl_mode) {
  MS_EXCEPTION_IF_NULL(context);
  MutableParams(context)[kModelOptionOpSelectImplMode] = CharToString(op_select_impl_mode);
}

// Returns the op-select implementation mode string, or an empty sequence if unset.
std::vector<char> ModelContext::GetOpSelectImplModeChar(const std::shared_ptr<Context> &context) {
  MS_EXCEPTION_IF_NULL(context);
  return StringToChar(GetValue<std::string>(context, kModelOptionOpSelectImplMode));
}
}  // namespace mindspore