|
|
|
@@ -79,6 +79,7 @@ |
|
|
|
#include "utils/ms_context.h" |
|
|
|
#include "utils/context/graph_kernel_flags.h" |
|
|
|
#include "utils/utils.h" |
|
|
|
#include "abstract/utils.h" |
|
|
|
#if ENABLE_CPU && ENABLE_GPU |
|
|
|
#include "ps/util.h" |
|
|
|
#include "ps/ps_cache/ps_cache_manager.h" |
|
|
|
@@ -269,6 +270,19 @@ bool UpdatedByAssign(const KernelGraphPtr &kernel_graph, const AnfNodePtr &node) |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
size_t GPUSession::UpdateGraphInputAbstract(AnfNodePtr input_node, tensor::TensorPtr tensor) { |
|
|
|
size_t size = LongToSize(tensor->data().nbytes()); |
|
|
|
if (input_node->isa<Parameter>() && input_node->cast<ParameterPtr>()->is_used_by_dynamic_kernel()) { |
|
|
|
auto tensor_shape = tensor->shape(); |
|
|
|
std::vector<size_t> shape_tmp; |
|
|
|
(void)std::transform(tensor_shape.begin(), tensor_shape.end(), std::back_inserter(shape_tmp), IntToSize); |
|
|
|
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {shape_tmp}, |
|
|
|
input_node.get()); |
|
|
|
size = abstract::ShapeSize(shape_tmp) * abstract::TypeIdSize(tensor->data_type()); |
|
|
|
} |
|
|
|
return size; |
|
|
|
} |
|
|
|
|
|
|
|
void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, |
|
|
|
const std::vector<tensor::TensorPtr> &inputs_const) const { |
|
|
|
std::vector<tensor::TensorPtr> inputs(inputs_const); |
|
|
|
@@ -314,8 +328,8 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, |
|
|
|
tensor->set_device_address(device_address); |
|
|
|
} |
|
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
|
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), |
|
|
|
LongToSize(tensor->data().nbytes()), tensor->data_type(), |
|
|
|
auto size = UpdateGraphInputAbstract(input_node, tensor); |
|
|
|
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), size, tensor->data_type(), |
|
|
|
tensor->data_c())) { |
|
|
|
MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; |
|
|
|
} |
|
|
|
|