| @@ -572,7 +572,6 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con | |||||
| // run graph steps | // run graph steps | ||||
| void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, | ||||
| const std::vector<tensor::TensorPtr> &inputs_const) const { | const std::vector<tensor::TensorPtr> &inputs_const) const { | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||||
| std::vector<tensor::TensorPtr> inputs(inputs_const); | std::vector<tensor::TensorPtr> inputs(inputs_const); | ||||
| size_t input_ctrl_size = 1; | size_t input_ctrl_size = 1; | ||||
| MS_EXCEPTION_IF_NULL(context_); | MS_EXCEPTION_IF_NULL(context_); | ||||
| @@ -585,6 +584,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap | |||||
| MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() | MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() | ||||
| << ", input_ctrl_size:" << input_ctrl_size; | << ", input_ctrl_size:" << input_ctrl_size; | ||||
| } | } | ||||
| auto ms_context = MsContext::GetInstance(); | |||||
| MS_EXCEPTION_IF_NULL(ms_context); | |||||
| for (size_t i = 0; i < inputs.size(); ++i) { | for (size_t i = 0; i < inputs.size(); ++i) { | ||||
| auto tensor = inputs[i]; | auto tensor = inputs[i]; | ||||
| MS_EXCEPTION_IF_NULL(tensor); | MS_EXCEPTION_IF_NULL(tensor); | ||||
| @@ -594,8 +595,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap | |||||
| auto pk_node = input_node->cast<ParameterPtr>(); | auto pk_node = input_node->cast<ParameterPtr>(); | ||||
| auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); | auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); | ||||
| bool need_sync = false; | bool need_sync = false; | ||||
| MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); | |||||
| if (MsContext::GetInstance()->enable_pynative_infer()) { | |||||
| if (ms_context->enable_pynative_infer()) { | |||||
| if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) { | if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) { | ||||
| need_sync = true; | need_sync = true; | ||||
| } | } | ||||