|
|
|
@@ -32,7 +32,6 @@ using mindspore::tensor::TensorPy; |
|
|
|
namespace mindspore { |
|
|
|
namespace session { |
|
|
|
namespace { |
|
|
|
// File-local set of graph parameter nodes. In the code visible here, a weight
// parameter that is already present in this set is skipped, and a node is
// inserted right after the SyncHostToDevice failure check for its weight data.
// NOTE(review): this view is a diff excerpt with old and new lines mixed, so the
// set may no longer be used in the final file — confirm against the full source.
std::set<AnfNodePtr> weight_infos; |
|
|
|
static TypeId GetDataType(const py::buffer_info &buf) { |
|
|
|
if (buf.format.size() == 1) { |
|
|
|
switch (buf.format.front()) { |
|
|
|
@@ -105,10 +104,33 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k |
|
|
|
MS_EXCEPTION_IF_NULL(pk_node); |
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); |
|
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
|
if (AnfAlgo::IsParameterWeight(pk_node)) { |
|
|
|
if (weight_infos.count(pk_node) != 0) { |
|
|
|
continue; |
|
|
|
if (!AnfAlgo::IsParameterWeight(pk_node)) { |
|
|
|
tensor = inputs[no_weight_input++]; |
|
|
|
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), |
|
|
|
LongToSize(tensor->data().nbytes()), tensor->data_type(), |
|
|
|
tensor->data_c())) { |
|
|
|
MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
GraphId AscendInferenceSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { |
|
|
|
auto graph_id = AscendSession::CompileGraph(func_graph); |
|
|
|
auto kernel_graph = GetGraph(graph_id); |
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph); |
|
|
|
// load weight data to device |
|
|
|
auto input_nodes = kernel_graph->inputs(); |
|
|
|
for (size_t i = 0; i < input_nodes.size(); ++i) { |
|
|
|
if (!input_nodes[i]->isa<Parameter>()) { |
|
|
|
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter"; |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto pk_node = input_nodes[i]->cast<ParameterPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(pk_node); |
|
|
|
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); |
|
|
|
MS_EXCEPTION_IF_NULL(device_address); |
|
|
|
if (AnfAlgo::IsParameterWeight(pk_node)) { |
|
|
|
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(pk_node->default_param()); |
|
|
|
MS_EXCEPTION_IF_NULL(param_value); |
|
|
|
auto py_param = param_value->value(); |
|
|
|
@@ -120,16 +142,9 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k |
|
|
|
LongToSize(buf.size * buf.itemsize), buf_type, buf.ptr)) { |
|
|
|
MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; |
|
|
|
} |
|
|
|
weight_infos.insert(pk_node); |
|
|
|
} else { |
|
|
|
tensor = inputs[no_weight_input++]; |
|
|
|
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), |
|
|
|
LongToSize(tensor->data().nbytes()), tensor->data_type(), |
|
|
|
tensor->data_c())) { |
|
|
|
MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return graph_id; |
|
|
|
} |
|
|
|
} // namespace session |
|
|
|
} // namespace mindspore |