|
|
|
@@ -95,14 +95,18 @@ uint32_t CheckpointSaveCallback(uint32_t graph_id, const std::map<std::string, g |
|
|
|
for (auto &item : params_list) { |
|
|
|
std::string name = item.first; |
|
|
|
std::shared_ptr<ge::Tensor> ge_tensor_ptr = std::make_shared<ge::Tensor>(item.second); |
|
|
|
TensorPtr tensor_ptr = GetMeTensorTransformed(graph_id, name, ge_tensor_ptr); |
|
|
|
if (tensor_ptr == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Transform ge tensor to me tensor failed"; |
|
|
|
if (name.size() > 5 && name.compare(name.size() - 5, 5, "_temp") == 0) { |
|
|
|
continue; |
|
|
|
} else { |
|
|
|
TensorPtr tensor_ptr = GetMeTensorTransformed(graph_id, name, ge_tensor_ptr); |
|
|
|
if (tensor_ptr == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "Transform ge tensor to me tensor failed"; |
|
|
|
} |
|
|
|
py::dict param_dict; |
|
|
|
param_dict["name"] = name; |
|
|
|
param_dict["data"] = tensor_ptr; |
|
|
|
parameter_list.append(param_dict); |
|
|
|
} |
|
|
|
py::dict param_dict; |
|
|
|
param_dict["name"] = name; |
|
|
|
param_dict["data"] = tensor_ptr; |
|
|
|
parameter_list.append(param_dict); |
|
|
|
} |
|
|
|
py::bool_ ret = |
|
|
|
parse::python_adapter::CallPyFn(PYTHON_MOD_CALLBACK_MODULE, PYTHON_FUN_PROCESS_CHECKPOINT, parameter_list); |
|
|
|
|