|
|
|
@@ -572,7 +572,6 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con |
|
|
|
// run graph steps |
|
|
|
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, |
|
|
|
const std::vector<tensor::TensorPtr> &inputs_const) const { |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
std::vector<tensor::TensorPtr> inputs(inputs_const); |
|
|
|
size_t input_ctrl_size = 1; |
|
|
|
MS_EXCEPTION_IF_NULL(context_); |
|
|
|
@@ -585,6 +584,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap |
|
|
|
MS_LOG(EXCEPTION) << "tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() |
|
|
|
<< ", input_ctrl_size:" << input_ctrl_size; |
|
|
|
} |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(ms_context); |
|
|
|
for (size_t i = 0; i < inputs.size(); ++i) { |
|
|
|
auto tensor = inputs[i]; |
|
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
|
@@ -594,8 +595,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap |
|
|
|
auto pk_node = input_node->cast<ParameterPtr>(); |
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); |
|
|
|
bool need_sync = false; |
|
|
|
MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); |
|
|
|
if (MsContext::GetInstance()->enable_pynative_infer()) { |
|
|
|
if (ms_context->enable_pynative_infer()) { |
|
|
|
if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) { |
|
|
|
need_sync = true; |
|
|
|
} |
|
|
|
|