Browse Source

!5372 fix laod input data error in pynative mode on gpu for master

Merge pull request !5372 from chujinjin/fix_load_input_data_error_in_pynative_for_master
tags/v1.0.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
4a86243c34
1 changed files with 5 additions and 1 deletions
  1. +5
    -1
      mindspore/ccsrc/backend/session/gpu_session.cc

+ 5
- 1
mindspore/ccsrc/backend/session/gpu_session.cc View File

@@ -133,7 +133,11 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const { const std::vector<tensor::TensorPtr> &inputs_const) const {
std::vector<tensor::TensorPtr> inputs(inputs_const); std::vector<tensor::TensorPtr> inputs(inputs_const);
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
auto input_nodes = kernel_graph->inputs();
std::vector<AnfNodePtr> input_nodes;
for (const auto &input_node : kernel_graph->inputs()) {
auto params = AnfAlgo::GetAllOutput(input_node);
std::copy(params.begin(), params.end(), std::back_inserter(input_nodes));
}
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (inputs.size() != input_nodes.size()) { if (inputs.size() != input_nodes.size()) {


Loading…
Cancel
Save