Browse Source

gpu inference context

pull/13077/head
wilfChen 4 years ago
parent
commit
db2668d72a
7 changed files with 31 additions and 3 deletions
  1. +2
    -0
      include/api/cell.h
  2. +0
    -1
      mindspore/ccsrc/backend/session/gpu_inference_session.cc
  3. +12
    -0
      mindspore/ccsrc/cxx_api/cell.cc
  4. +11
    -1
      mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc
  5. +3
    -1
      mindspore/ccsrc/cxx_api/graph/graph_impl.h
  6. +2
    -0
      mindspore/ccsrc/cxx_api/model/ms/ms_model.cc
  7. +1
    -0
      mindspore/core/utils/ms_context.cc

+ 2
- 0
include/api/cell.h View File

@@ -25,6 +25,7 @@

namespace mindspore {
class InputAndOutput;
class Context;
using Input = InputAndOutput;
using Output = InputAndOutput;

@@ -97,6 +98,7 @@ class MS_API GraphCell final : public Cell<GraphCell> {
explicit GraphCell(Graph &&);
explicit GraphCell(const std::shared_ptr<Graph> &);

void SetContext(const std::shared_ptr<Context> &context);
const std::shared_ptr<Graph> &GetGraph() const { return graph_; }
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
std::vector<MSTensor> GetInputs();


+ 0
- 1
mindspore/ccsrc/backend/session/gpu_inference_session.cc View File

@@ -212,6 +212,5 @@ std::string GpuInferenceSession::InputsInfo(const std::vector<ParameterPtr> &par
}
return graph + " " + actual;
}

} // namespace session
} // namespace mindspore

+ 12
- 0
mindspore/ccsrc/cxx_api/cell.cc View File

@@ -73,6 +73,18 @@ GraphCell::GraphCell(const std::shared_ptr<Graph> &graph) : graph_(graph) { MS_E

GraphCell::GraphCell(Graph &&graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); }

void GraphCell::SetContext(const std::shared_ptr<Context> &context) {
if (executor_ == nullptr) {
executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
if (executor_ == nullptr) {
MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
return;
}
executor_->SetGraph(graph_);
}
executor_->SetContext(context);
}

Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
if (executor_ == nullptr) {
executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);


+ 11
- 1
mindspore/ccsrc/cxx_api/graph/gpu/gpu_graph_impl.cc View File

@@ -54,7 +54,17 @@ Status GPUGraphImpl::InitEnv() {
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kGPUDevice);
ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false);

auto &device_infos = graph_context_->MutableDeviceInfo();
if (device_infos.size() != 1) {
return kMCDeviceError;
}
auto gpu_info = device_infos[0]->Cast<NvidiaGPUDeviceInfo>();
if (gpu_info == nullptr) {
return kMCDeviceError;
}
auto enable_trt = gpu_info->GetGpuTrtInferMode();
ms_context->set_param<bool>(MS_CTX_ENABLE_INFER_OPT, enable_trt);

session_impl_ = session::SessionFactory::Get().Create(kGpuInferenceDevice);
if (session_impl_ == nullptr) {


+ 3
- 1
mindspore/ccsrc/cxx_api/graph/graph_impl.h View File

@@ -29,11 +29,12 @@
namespace mindspore {
class GraphCell::GraphImpl {
public:
GraphImpl() : graph_(nullptr) {}
GraphImpl() : graph_(nullptr), graph_context_(nullptr) {}
virtual ~GraphImpl() = default;

std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; }
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
void SetContext(const std::shared_ptr<Context> &context) { graph_context_ = context; }

virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;
virtual Status Load(uint32_t device_id) = 0;
@@ -43,6 +44,7 @@ class GraphCell::GraphImpl {

protected:
std::shared_ptr<Graph> graph_;
std::shared_ptr<Context> graph_context_;
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H

+ 2
- 0
mindspore/ccsrc/cxx_api/model/ms/ms_model.cc View File

@@ -74,6 +74,7 @@ std::shared_ptr<GraphCell> MsModel::GenerateGraphCell(const std::vector<std::vec
MS_EXCEPTION_IF_NULL(graph);
auto graph_cell = std::make_shared<GraphCell>(graph);
MS_EXCEPTION_IF_NULL(graph_cell);
graph_cell->SetContext(model_context_);
auto ret = ModelImpl::Load(graph_cell, GetDeviceID());
if (ret != kSuccess) {
MS_LOG(ERROR) << "Load failed.";
@@ -99,6 +100,7 @@ Status MsModel::Build() {
MS_EXCEPTION_IF_NULL(graph);
auto graph_cell = std::make_shared<GraphCell>(graph);
MS_EXCEPTION_IF_NULL(graph_cell);
graph_cell->SetContext(model_context_);
auto ret = ModelImpl::Load(graph_cell, GetDeviceID());
if (ret != kSuccess) {
MS_LOG(ERROR) << "Load failed.";


+ 1
- 0
mindspore/core/utils/ms_context.cc View File

@@ -83,6 +83,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
set_param<bool>(MS_CTX_ENABLE_GRAPH_KERNEL, false);
set_param<bool>(MS_CTX_ENABLE_SPARSE, false);
set_param<bool>(MS_CTX_ENABLE_PARALLEL_SPLIT, false);
set_param<bool>(MS_CTX_ENABLE_INFER_OPT, false);
set_param<bool>(MS_CTX_GRAD_FOR_SCALAR, false);

backend_policy_ = policy_map_[policy];


Loading…
Cancel
Save