Browse Source

fix bcewithlogitsloss op error in pynative

tags/v1.3.0
chujinjin 4 years ago
parent
commit
90feb6a6d2
2 changed files with 13 additions and 0 deletions
  1. +11
    -0
      mindspore/ccsrc/backend/session/gpu_session.cc
  2. +2
    -0
      mindspore/ccsrc/backend/session/gpu_session.h

+ 11
- 0
mindspore/ccsrc/backend/session/gpu_session.cc View File

@@ -189,6 +189,16 @@ void GPUSession::HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_gra
kernel_graph->SetExecOrderByDefault();
}

void GPUSession::RunOpOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
pm->AddPass(std::make_shared<opt::BCEWithLogitsLossFusion>());
optimizer->AddPassManager(pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
}

void GPUSession::RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>();
@@ -558,6 +568,7 @@ void GPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &grap
// Prepare the graph
auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
MS_EXCEPTION_IF_NULL(kernel_graph);
RunOpOptimize(kernel_graph);
SelectKernel(kernel_graph);
RunOpHardwareOptimize(kernel_graph);
StartKernelRT();


+ 2
- 0
mindspore/ccsrc/backend/session/gpu_session.h View File

@@ -66,6 +66,8 @@ class GPUSession : public SessionBasic {

void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);

void RunOpOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);

void RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);

void GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);


Loading…
Cancel
Save