From: @wilfchen
Reviewed-by: @cristoval, @limingqi107
Signed-off-by: @limingqi107
Tag: v1.2.0-rc1
@@ -22,6 +22,7 @@
 #include "include/api/status.h"
 #include "include/api/types.h"
 #include "include/api/graph.h"
+#include "include/api/context.h"

 namespace mindspore {
 class InputAndOutput;
@@ -97,6 +98,7 @@ class MS_API GraphCell final : public Cell<GraphCell> {
   explicit GraphCell(Graph &&);
   explicit GraphCell(const std::shared_ptr<Graph> &);

+  void SetContext(const std::shared_ptr<Context> &context);
   const std::shared_ptr<Graph> &GetGraph() const { return graph_; }
   Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
   std::vector<MSTensor> GetInputs();
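A minimal usage sketch of the new GraphCell::SetContext (not part of this patch): the helper below attaches a caller-supplied Context before running the cell. It assumes the Graph has already been loaded elsewhere and that the backend initializes the graph on first use; the name RunWithContext is illustrative only.

    #include <memory>
    #include <vector>
    #include "include/api/cell.h"
    #include "include/api/context.h"
    #include "include/api/graph.h"
    #include "include/api/status.h"
    #include "include/api/types.h"

    // Hypothetical helper: bind a per-graph Context, then execute the cell.
    mindspore::Status RunWithContext(const std::shared_ptr<mindspore::Graph> &graph,
                                     const std::shared_ptr<mindspore::Context> &context,
                                     const std::vector<mindspore::MSTensor> &inputs,
                                     std::vector<mindspore::MSTensor> *outputs) {
      mindspore::GraphCell cell(graph);
      cell.SetContext(context);  // new API introduced by this change
      return cell.Run(inputs, outputs);
    }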
@@ -81,6 +81,9 @@ struct MS_API ModelContext : public Context {
   static inline void SetFusionSwitchConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
   static inline std::string GetFusionSwitchConfigPath(const std::shared_ptr<Context> &context);
+
+  static inline void SetGpuTrtInferMode(const std::shared_ptr<Context> &context, const std::string &gpu_trt_infer_mode);
+  static inline std::string GetGpuTrtInferMode(const std::shared_ptr<Context> &context);

  private:
   // api without std::string
   static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path);
@@ -101,6 +104,9 @@ struct MS_API ModelContext : public Context {
   static void SetFusionSwitchConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path);
   static std::vector<char> GetFusionSwitchConfigPathChar(const std::shared_ptr<Context> &context);
+
+  static void SetGpuTrtInferMode(const std::shared_ptr<Context> &context, const std::vector<char> &gpu_trt_infer_mode);
+  static std::vector<char> GetGpuTrtInferModeChar(const std::shared_ptr<Context> &context);
 };

 void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
@@ -155,5 +161,12 @@ void ModelContext::SetFusionSwitchConfigPath(const std::shared_ptr<Context> &con
 std::string ModelContext::GetFusionSwitchConfigPath(const std::shared_ptr<Context> &context) {
   return CharToString(GetFusionSwitchConfigPathChar(context));
 }
+
+void ModelContext::SetGpuTrtInferMode(const std::shared_ptr<Context> &context, const std::string &gpu_trt_infer_mode) {
+  SetGpuTrtInferMode(context, StringToChar(gpu_trt_infer_mode));
+}
+
+std::string ModelContext::GetGpuTrtInferMode(const std::shared_ptr<Context> &context) {
+  return CharToString(GetGpuTrtInferModeChar(context));
+}

 }  // namespace mindspore
 #endif  // MINDSPORE_INCLUDE_API_CONTEXT_H
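The new accessors follow the same string-in, std::vector<char>-out pattern as the other ModelContext options. A hedged sketch of the round trip (assuming a ModelContext can be default-constructed and passed around as a std::shared_ptr<Context>, which is how the existing setters in this header are used):

    #include <memory>
    #include <string>
    #include "include/api/context.h"

    std::shared_ptr<mindspore::Context> MakeTrtContext() {
      auto context = std::make_shared<mindspore::ModelContext>();
      // Only the exact string "True" switches inference to the TensorRT engine.
      mindspore::ModelContext::SetGpuTrtInferMode(context, "True");
      // Reads back "True"; if the option were never set, the getter would
      // presumably fall back to an empty string (see GetValue in context.cc).
      std::string mode = mindspore::ModelContext::GetGpuTrtInferMode(context);
      (void)mode;
      return context;
    }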
@@ -78,6 +78,8 @@ GraphCell::GraphCell(Graph &&graph)
   executor_->SetGraph(graph_);
 }

+void GraphCell::SetContext(const std::shared_ptr<Context> &context) { return executor_->SetContext(context); }
+
 Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
   MS_EXCEPTION_IF_NULL(executor_);
   return executor_->Run(inputs, outputs);
@@ -31,6 +31,8 @@ constexpr auto kModelOptionPrecisionMode = "mindspore.option.precision_mode";
 // "force_fp16", "allow_fp32_to_fp16", "must_keep_origin_dtype" or "allow_mix_precision", default as "force_fp16"
 constexpr auto kModelOptionOpSelectImplMode = "mindspore.option.op_select_impl_mode";
 constexpr auto KModelOptionFusionSwitchCfgPath = "mindspore.option.fusion_switch_config_file_path";
+// "False": Inference with native backend, "True": Inference with Tensor-RT engine, default as "False"
+constexpr auto kModelOptionGpuTrtInferMode = "mindspore.option.gpu_trt_infer_mode";

 namespace mindspore {
 struct Context::Data {
@@ -217,4 +219,20 @@ std::vector<char> ModelContext::GetFusionSwitchConfigPathChar(const std::shared_
   const std::string &ref = GetValue<std::string>(context, KModelOptionFusionSwitchCfgPath);
   return StringToChar(ref);
 }
+
+void ModelContext::SetGpuTrtInferMode(const std::shared_ptr<Context> &context,
+                                      const std::vector<char> &gpu_trt_infer_mode) {
+  MS_EXCEPTION_IF_NULL(context);
+  if (context->data == nullptr) {
+    context->data = std::make_shared<Data>();
+    MS_EXCEPTION_IF_NULL(context->data);
+  }
+  context->data->params[kModelOptionGpuTrtInferMode] = CharToString(gpu_trt_infer_mode);
+}
+
+std::vector<char> ModelContext::GetGpuTrtInferModeChar(const std::shared_ptr<Context> &context) {
+  MS_EXCEPTION_IF_NULL(context);
+  const std::string &ref = GetValue<std::string>(context, kModelOptionGpuTrtInferMode);
+  return StringToChar(ref);
+}
 }  // namespace mindspore
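The setter above stores the option as a plain string in the params map on Context::Data, keyed by kModelOptionGpuTrtInferMode. Below is a hedged, self-contained sketch of that lookup pattern (the real GetValue helper and Data layout in context.cc are more general; FakeData and LookupOption are illustrative names only). It shows why an option that was never set behaves like the documented default "False": the lookup yields an empty string, which later fails the == "True" check.

    #include <map>
    #include <memory>
    #include <string>

    // Simplified stand-in for Context::Data; the real params map holds typed values.
    struct FakeData {
      std::map<std::string, std::string> params;
    };

    // Simplified stand-in for GetValue<std::string>: missing data or a missing key
    // falls back to a default-constructed (empty) string.
    std::string LookupOption(const std::shared_ptr<FakeData> &data, const std::string &key) {
      if (data == nullptr) {
        return std::string();
      }
      auto iter = data->params.find(key);
      return iter == data->params.end() ? std::string() : iter->second;
    }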
@@ -51,7 +51,10 @@ Status GPUGraphImpl::InitEnv() {
   ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
   ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
   ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kGPUDevice);
-  ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, true);
+  auto enable_trt = ModelContext::GetGpuTrtInferMode(graph_context_);
+  if (enable_trt == "True") {
+    ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, true);
+  }

   session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice);
   if (session_impl_ == nullptr) {
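Before this change the GPU backend always set MS_CTX_ENABLE_INFER_OPT; now it is set only when the user opts in. The comparison is an exact, case-sensitive string match, so values such as "true" or "1" keep the native backend, which is consistent with the "default as False" comment on the option. A small illustrative helper (not in the patch) that captures the decision:

    #include <string>

    // Assumption drawn from the hunk above: anything other than the exact
    // string "True" leaves TensorRT-backed inference disabled.
    inline bool ShouldEnableTrt(const std::string &gpu_trt_infer_mode) {
      return gpu_trt_infer_mode == "True";
    }

Keeping the option as a string rather than a bool matches the other ModelContext options, which are all stored as strings in the context's params map.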
@@ -29,11 +29,12 @@
 namespace mindspore {
 class GraphCell::GraphImpl {
  public:
-  GraphImpl() = default;
+  GraphImpl() : graph_(nullptr), graph_context_(nullptr) {}
   virtual ~GraphImpl() = default;

   std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; }
   void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
+  void SetContext(const std::shared_ptr<Context> &context) { graph_context_ = context; }

   virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;
   virtual Status Load() = 0;
@@ -43,6 +44,7 @@ class GraphCell::GraphImpl {
  protected:
   std::shared_ptr<Graph> graph_;
+  std::shared_ptr<Context> graph_context_;
 };
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
@@ -70,6 +70,7 @@ std::shared_ptr<GraphCell> MsModel::GenerateGraphCell(const std::vector<std::vec
   MS_EXCEPTION_IF_NULL(graph);
   auto graph_cell = std::make_shared<GraphCell>(graph);
   MS_EXCEPTION_IF_NULL(graph_cell);
+  graph_cell->SetContext(model_context_);
   auto ret = ModelImpl::Load(graph_cell);
   if (ret != kSuccess) {
     MS_LOG(ERROR) << "Load failed.";
@@ -95,6 +96,7 @@ Status MsModel::Build() {
   MS_EXCEPTION_IF_NULL(graph);
   auto graph_cell = std::make_shared<GraphCell>(graph);
   MS_EXCEPTION_IF_NULL(graph_cell);
+  graph_cell->SetContext(model_context_);
   auto ret = ModelImpl::Load(graph_cell);
   if (ret != kSuccess) {
     MS_LOG(ERROR) << "Load failed.";
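Taken together, the patch threads the context through the whole stack: MsModel attaches model_context_ to every GraphCell it builds, GraphCell forwards it to its GraphImpl, and GPUGraphImpl reads the GPU TRT option from it in InitEnv. Because InitEnv runs when the graph environment is first initialized, the option must be set on the model's context before the graph is built/loaded. A hedged, caller-side sketch (EnableTrtBeforeBuild is a hypothetical helper; the Model-level Build/Predict calls themselves are not shown here):

    #include <memory>
    #include "include/api/context.h"

    // Enable TensorRT on the context that will be handed to the model. The
    // graph_cell->SetContext(model_context_) calls added above propagate this
    // automatically once the model builds or loads its graph cells.
    void EnableTrtBeforeBuild(const std::shared_ptr<mindspore::Context> &model_context) {
      mindspore::ModelContext::SetGpuTrtInferMode(model_context, "True");
    }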