|
|
|
@@ -475,7 +475,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K |
|
|
|
cnode_inputs.emplace_back(new_value_node); |
|
|
|
} |
|
|
|
continue; |
|
|
|
} else if (anf->isa<Parameter>() && AnfAlgo::GetOutputTensorNum(anf) == 1) { |
|
|
|
} else if (anf->isa<Parameter>()) { |
|
|
|
auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph); |
|
|
|
cnode_inputs.push_back(new_parameter); |
|
|
|
if (GetGraphIdByNode(anf) == kInvalidGraphId) { |
|
|
|
@@ -818,6 +818,25 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &para |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
namespace { |
|
|
|
bool TensorNeedSync(const AnfNodePtr ¶meter, const tensor::TensorPtr &tensor) { |
|
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
|
MS_EXCEPTION_IF_NULL(ms_context); |
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0); |
|
|
|
if (ms_context->enable_pynative_infer()) { |
|
|
|
return tensor->device_address().get() == nullptr || tensor->device_address() != device_address; |
|
|
|
} |
|
|
|
if (tensor->is_dirty()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
if (tensor->device_address() != device_address) { |
|
|
|
(void)tensor->data_sync(); |
|
|
|
return true; |
|
|
|
} |
|
|
|
return false; |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
// run graph steps |
|
|
|
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, |
|
|
|
const std::vector<tensor::TensorPtr> &inputs_const) const { |
|
|
|
@@ -827,7 +846,11 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap |
|
|
|
if (kernel_graph->input_ctrl_tensors()) { |
|
|
|
input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); |
|
|
|
} |
|
|
|
auto input_nodes = kernel_graph->inputs(); |
|
|
|
std::vector<AnfNodePtr> input_nodes; |
|
|
|
for (const auto &input_node : kernel_graph->inputs()) { |
|
|
|
auto params = AnfAlgo::GetAllOutput(input_node); |
|
|
|
std::copy(params.begin(), params.end(), std::back_inserter(input_nodes)); |
|
|
|
} |
|
|
|
if ((inputs.size() + input_ctrl_size) - 2 != input_nodes.size()) { |
|
|
|
MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() |
|
|
|
<< ", input_ctrl_size:" << input_ctrl_size; |
|
|
|
@@ -838,33 +861,17 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap |
|
|
|
auto tensor = inputs[i]; |
|
|
|
MS_EXCEPTION_IF_NULL(tensor); |
|
|
|
auto input_node = input_nodes[i]; |
|
|
|
MS_EXCEPTION_IF_NULL(input_node); |
|
|
|
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) { |
|
|
|
auto pk_node = input_node->cast<ParameterPtr>(); |
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); |
|
|
|
bool need_sync = false; |
|
|
|
if (ms_context->enable_pynative_infer()) { |
|
|
|
if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) { |
|
|
|
need_sync = true; |
|
|
|
} |
|
|
|
} else { |
|
|
|
if (tensor->is_dirty()) { |
|
|
|
need_sync = true; |
|
|
|
} else if (tensor->device_address() != device_address) { |
|
|
|
(void)tensor->data_sync(); |
|
|
|
need_sync = true; |
|
|
|
} |
|
|
|
if (TensorNeedSync(input_node, tensor) && input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) { |
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); |
|
|
|
if (ms_context->execution_mode() == kPynativeMode || |
|
|
|
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) { |
|
|
|
tensor->set_device_address(device_address); |
|
|
|
} |
|
|
|
if (need_sync) { |
|
|
|
if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) { |
|
|
|
tensor->set_device_address(device_address); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
|
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), |
|
|
|
LongToSize(tensor->data().nbytes()), tensor->data_type(), |
|
|
|
tensor->data_c())) { |
|
|
|
MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
|
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), |
|
|
|
LongToSize(tensor->data().nbytes()), tensor->data_type(), |
|
|
|
tensor->data_c())) { |
|
|
|
MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; |
|
|
|
} |
|
|
|
} |
|
|
|
tensor->set_dirty(false); |
|
|
|
|