Browse Source

!12974 GPU inference config

From: @wilfchen
Reviewed-by: @limingqi107,@zhanghaibo5
Signed-off-by: @zhanghaibo5
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
71cecd939c
5 changed files with 2 additions and 13 deletions
  1. +0
    -2
      include/api/cell.h
  2. +0
    -2
      mindspore/ccsrc/cxx_api/cell.cc
  3. +1
    -4
      mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc
  4. +1
    -3
      mindspore/ccsrc/cxx_api/graph/graph_impl.h
  5. +0
    -2
      mindspore/ccsrc/cxx_api/model/ms/ms_model.cc

+ 0
- 2
include/api/cell.h View File

@@ -22,7 +22,6 @@
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/api/graph.h"
#include "include/api/context.h"

namespace mindspore {
class InputAndOutput;
@@ -98,7 +97,6 @@ class MS_API GraphCell final : public Cell<GraphCell> {
explicit GraphCell(Graph &&);
explicit GraphCell(const std::shared_ptr<Graph> &);

void SetContext(const std::shared_ptr<Context> &context);
const std::shared_ptr<Graph> &GetGraph() const { return graph_; }
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
std::vector<MSTensor> GetInputs();


+ 0
- 2
mindspore/ccsrc/cxx_api/cell.cc View File

@@ -78,8 +78,6 @@ GraphCell::GraphCell(Graph &&graph)
executor_->SetGraph(graph_);
}

void GraphCell::SetContext(const std::shared_ptr<Context> &context) { return executor_->SetContext(context); }

Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(executor_);
return executor_->Run(inputs, outputs);


+ 1
- 4
mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc View File

@@ -51,10 +51,7 @@ Status GPUGraphImpl::InitEnv() {
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kGPUDevice);
auto enable_trt = ModelContext::GetGpuTrtInferMode(graph_context_);
if (enable_trt == "True") {
ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, true);
}
ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false);

session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice);
if (session_impl_ == nullptr) {


+ 1
- 3
mindspore/ccsrc/cxx_api/graph/graph_impl.h View File

@@ -29,12 +29,11 @@
namespace mindspore {
class GraphCell::GraphImpl {
public:
GraphImpl() : graph_(nullptr), graph_context_(nullptr) {}
GraphImpl() : graph_(nullptr) {}
virtual ~GraphImpl() = default;

std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; }
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
void SetContext(const std::shared_ptr<Context> &context) { graph_context_ = context; }

virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;
virtual Status Load() = 0;
@@ -44,7 +43,6 @@ class GraphCell::GraphImpl {

protected:
std::shared_ptr<Graph> graph_;
std::shared_ptr<Context> graph_context_;
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H

+ 0
- 2
mindspore/ccsrc/cxx_api/model/ms/ms_model.cc View File

@@ -70,7 +70,6 @@ std::shared_ptr<GraphCell> MsModel::GenerateGraphCell(const std::vector<std::vec
MS_EXCEPTION_IF_NULL(graph);
auto graph_cell = std::make_shared<GraphCell>(graph);
MS_EXCEPTION_IF_NULL(graph_cell);
graph_cell->SetContext(model_context_);
auto ret = ModelImpl::Load(graph_cell);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Load failed.";
@@ -96,7 +95,6 @@ Status MsModel::Build() {
MS_EXCEPTION_IF_NULL(graph);
auto graph_cell = std::make_shared<GraphCell>(graph);
MS_EXCEPTION_IF_NULL(graph_cell);
graph_cell->SetContext(model_context_);
auto ret = ModelImpl::Load(graph_cell);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Load failed.";


Loading…
Cancel
Save