From: @wilfchen Reviewed-by: @limingqi107,@zhanghaibo5 Signed-off-by: @zhanghaibo5tags/v1.2.0-rc1
| @@ -22,7 +22,6 @@ | |||||
| #include "include/api/status.h" | #include "include/api/status.h" | ||||
| #include "include/api/types.h" | #include "include/api/types.h" | ||||
| #include "include/api/graph.h" | #include "include/api/graph.h" | ||||
| #include "include/api/context.h" | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class InputAndOutput; | class InputAndOutput; | ||||
| @@ -98,7 +97,6 @@ class MS_API GraphCell final : public Cell<GraphCell> { | |||||
| explicit GraphCell(Graph &&); | explicit GraphCell(Graph &&); | ||||
| explicit GraphCell(const std::shared_ptr<Graph> &); | explicit GraphCell(const std::shared_ptr<Graph> &); | ||||
| void SetContext(const std::shared_ptr<Context> &context); | |||||
| const std::shared_ptr<Graph> &GetGraph() const { return graph_; } | const std::shared_ptr<Graph> &GetGraph() const { return graph_; } | ||||
| Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override; | Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override; | ||||
| std::vector<MSTensor> GetInputs(); | std::vector<MSTensor> GetInputs(); | ||||
| @@ -78,8 +78,6 @@ GraphCell::GraphCell(Graph &&graph) | |||||
| executor_->SetGraph(graph_); | executor_->SetGraph(graph_); | ||||
| } | } | ||||
| void GraphCell::SetContext(const std::shared_ptr<Context> &context) { return executor_->SetContext(context); } | |||||
| Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { | Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { | ||||
| MS_EXCEPTION_IF_NULL(executor_); | MS_EXCEPTION_IF_NULL(executor_); | ||||
| return executor_->Run(inputs, outputs); | return executor_->Run(inputs, outputs); | ||||
| @@ -51,10 +51,7 @@ Status GPUGraphImpl::InitEnv() { | |||||
| ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode); | ||||
| ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_); | ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_); | ||||
| ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kGPUDevice); | ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kGPUDevice); | ||||
| auto enable_trt = ModelContext::GetGpuTrtInferMode(graph_context_); | |||||
| if (enable_trt == "True") { | |||||
| ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, true); | |||||
| } | |||||
| ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false); | |||||
| session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice); | session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice); | ||||
| if (session_impl_ == nullptr) { | if (session_impl_ == nullptr) { | ||||
| @@ -29,12 +29,11 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| class GraphCell::GraphImpl { | class GraphCell::GraphImpl { | ||||
| public: | public: | ||||
| GraphImpl() : graph_(nullptr), graph_context_(nullptr) {} | |||||
| GraphImpl() : graph_(nullptr) {} | |||||
| virtual ~GraphImpl() = default; | virtual ~GraphImpl() = default; | ||||
| std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; } | std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; } | ||||
| void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; } | void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; } | ||||
| void SetContext(const std::shared_ptr<Context> &context) { graph_context_ = context; } | |||||
| virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0; | virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0; | ||||
| virtual Status Load() = 0; | virtual Status Load() = 0; | ||||
| @@ -44,7 +43,6 @@ class GraphCell::GraphImpl { | |||||
| protected: | protected: | ||||
| std::shared_ptr<Graph> graph_; | std::shared_ptr<Graph> graph_; | ||||
| std::shared_ptr<Context> graph_context_; | |||||
| }; | }; | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H | #endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H | ||||
| @@ -70,7 +70,6 @@ std::shared_ptr<GraphCell> MsModel::GenerateGraphCell(const std::vector<std::vec | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| auto graph_cell = std::make_shared<GraphCell>(graph); | auto graph_cell = std::make_shared<GraphCell>(graph); | ||||
| MS_EXCEPTION_IF_NULL(graph_cell); | MS_EXCEPTION_IF_NULL(graph_cell); | ||||
| graph_cell->SetContext(model_context_); | |||||
| auto ret = ModelImpl::Load(graph_cell); | auto ret = ModelImpl::Load(graph_cell); | ||||
| if (ret != kSuccess) { | if (ret != kSuccess) { | ||||
| MS_LOG(ERROR) << "Load failed."; | MS_LOG(ERROR) << "Load failed."; | ||||
| @@ -96,7 +95,6 @@ Status MsModel::Build() { | |||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| auto graph_cell = std::make_shared<GraphCell>(graph); | auto graph_cell = std::make_shared<GraphCell>(graph); | ||||
| MS_EXCEPTION_IF_NULL(graph_cell); | MS_EXCEPTION_IF_NULL(graph_cell); | ||||
| graph_cell->SetContext(model_context_); | |||||
| auto ret = ModelImpl::Load(graph_cell); | auto ret = ModelImpl::Load(graph_cell); | ||||
| if (ret != kSuccess) { | if (ret != kSuccess) { | ||||
| MS_LOG(ERROR) << "Load failed."; | MS_LOG(ERROR) << "Load failed."; | ||||