| @@ -25,6 +25,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class InputAndOutput; | class InputAndOutput; | ||||
| class Context; | |||||
| using Input = InputAndOutput; | using Input = InputAndOutput; | ||||
| using Output = InputAndOutput; | using Output = InputAndOutput; | ||||
| @@ -97,6 +98,7 @@ class MS_API GraphCell final : public Cell<GraphCell> { | |||||
| explicit GraphCell(Graph &&); | explicit GraphCell(Graph &&); | ||||
| explicit GraphCell(const std::shared_ptr<Graph> &); | explicit GraphCell(const std::shared_ptr<Graph> &); | ||||
| void SetContext(const std::shared_ptr<Context> &context); | |||||
| const std::shared_ptr<Graph> &GetGraph() const { return graph_; } | const std::shared_ptr<Graph> &GetGraph() const { return graph_; } | ||||
| Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override; | Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override; | ||||
| std::vector<MSTensor> GetInputs(); | std::vector<MSTensor> GetInputs(); | ||||
| @@ -212,6 +212,5 @@ std::string GpuInferenceSession::InputsInfo(const std::vector<ParameterPtr> &par | |||||
| } | } | ||||
| return graph + " " + actual; | return graph + " " + actual; | ||||
| } | } | ||||
| } // namespace session | } // namespace session | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -73,6 +73,18 @@ GraphCell::GraphCell(const std::shared_ptr<Graph> &graph) : graph_(graph) { MS_E | |||||
| GraphCell::GraphCell(Graph &&graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); } | GraphCell::GraphCell(Graph &&graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); } | ||||
| void GraphCell::SetContext(const std::shared_ptr<Context> &context) { | |||||
| if (executor_ == nullptr) { | |||||
| executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target); | |||||
| if (executor_ == nullptr) { | |||||
| MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed."; | |||||
| return; | |||||
| } | |||||
| executor_->SetGraph(graph_); | |||||
| } | |||||
| executor_->SetContext(context); | |||||
| } | |||||
| Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { | Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { | ||||
| if (executor_ == nullptr) { | if (executor_ == nullptr) { | ||||
| executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target); | executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target); | ||||
| @@ -54,7 +54,17 @@ Status GPUGraphImpl::InitEnv() { | |||||
| ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | ||||
| ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_); | ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_); | ||||
| ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kGPUDevice); | ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kGPUDevice); | ||||
| ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false); | |||||
| auto &device_infos = graph_context_->MutableDeviceInfo(); | |||||
| if (device_infos.size() != 1) { | |||||
| return kMCDeviceError; | |||||
| } | |||||
| auto gpu_info = device_infos[0]->Cast<NvidiaGPUDeviceInfo>(); | |||||
| if (gpu_info == nullptr) { | |||||
| return kMCDeviceError; | |||||
| } | |||||
| auto enable_trt = gpu_info->GetGpuTrtInferMode(); | |||||
| ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, enable_trt); | |||||
| session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice); | session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice); | ||||
| if (session_impl_ == nullptr) { | if (session_impl_ == nullptr) { | ||||
| @@ -29,11 +29,12 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class GraphCell::GraphImpl { | class GraphCell::GraphImpl { | ||||
| public: | public: | ||||
| GraphImpl() : graph_(nullptr) {} | |||||
| GraphImpl() : graph_(nullptr), graph_context_(nullptr) {} | |||||
| virtual ~GraphImpl() = default; | virtual ~GraphImpl() = default; | ||||
| std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; } | std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; } | ||||
| void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; } | void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; } | ||||
| void SetContext(const std::shared_ptr<Context> &context) { graph_context_ = context; } | |||||
| virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0; | virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0; | ||||
| virtual Status Load(uint32_t device_id) = 0; | virtual Status Load(uint32_t device_id) = 0; | ||||
| @@ -43,6 +44,7 @@ class GraphCell::GraphImpl { | |||||
| protected: | protected: | ||||
| std::shared_ptr<Graph> graph_; | std::shared_ptr<Graph> graph_; | ||||
| std::shared_ptr<Context> graph_context_; | |||||
| }; | }; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H | #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H | ||||
| @@ -74,6 +74,7 @@ std::shared_ptr<GraphCell> MsModel::GenerateGraphCell(const std::vector<std::vec | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| auto graph_cell = std::make_shared<GraphCell>(graph); | auto graph_cell = std::make_shared<GraphCell>(graph); | ||||
| MS_EXCEPTION_IF_NULL(graph_cell); | MS_EXCEPTION_IF_NULL(graph_cell); | ||||
| graph_cell->SetContext(model_context_); | |||||
| auto ret = ModelImpl::Load(graph_cell, GetDeviceID()); | auto ret = ModelImpl::Load(graph_cell, GetDeviceID()); | ||||
| if (ret != kSuccess) { | if (ret != kSuccess) { | ||||
| MS_LOG(ERROR) << "Load failed."; | MS_LOG(ERROR) << "Load failed."; | ||||
| @@ -99,6 +100,7 @@ Status MsModel::Build() { | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| auto graph_cell = std::make_shared<GraphCell>(graph); | auto graph_cell = std::make_shared<GraphCell>(graph); | ||||
| MS_EXCEPTION_IF_NULL(graph_cell); | MS_EXCEPTION_IF_NULL(graph_cell); | ||||
| graph_cell->SetContext(model_context_); | |||||
| auto ret = ModelImpl::Load(graph_cell, GetDeviceID()); | auto ret = ModelImpl::Load(graph_cell, GetDeviceID()); | ||||
| if (ret != kSuccess) { | if (ret != kSuccess) { | ||||
| MS_LOG(ERROR) << "Load failed."; | MS_LOG(ERROR) << "Load failed."; | ||||
| @@ -83,6 +83,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { | |||||
| set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false); | set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false); | ||||
| set_param<bool>(MS_CTX_ENABLE_SPARSE, false); | set_param<bool>(MS_CTX_ENABLE_SPARSE, false); | ||||
| set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false); | set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false); | ||||
| set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false); | |||||
| set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false); | set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false); | ||||
| backend_policy_ = policy_map_[policy]; | backend_policy_ = policy_map_[policy]; | ||||