|
|
|
@@ -2843,6 +2843,33 @@ std::vector<size_t> GradExecutor::GetGradPositionArgs(const py::object &grad_pos |
|
|
|
MS_LOG(EXCEPTION) << "Grad position only support tuple."; |
|
|
|
} |
|
|
|
|
|
|
|
void GradExecutor::ShallowCopySensValue(const py::tuple &input_args, bool has_sens, VectorRef *run_args) { |
|
|
|
if (!has_sens) { |
|
|
|
return; |
|
|
|
} |
|
|
|
// Get index and number of sens args. |
|
|
|
size_t sens_index = input_args.size() - 1; |
|
|
|
size_t sens_num = 1; |
|
|
|
if (py::isinstance<py::tuple>(input_args[sens_index])) { |
|
|
|
py::tuple tuple_sens = py::cast<py::tuple>(input_args[sens_index]); |
|
|
|
sens_num = ConvertArgs(tuple_sens).size(); |
|
|
|
} |
|
|
|
// Shallow copy sens args to new sens args. |
|
|
|
MS_EXCEPTION_IF_NULL(run_args); |
|
|
|
for (size_t i = sens_index; i < sens_index + sens_num; ++i) { |
|
|
|
const auto &original_sens = (*run_args)[i]; |
|
|
|
if (utils::isa<ValuePtr>(original_sens)) { |
|
|
|
auto sens_value = utils::cast<ValuePtr>(original_sens); |
|
|
|
MS_EXCEPTION_IF_NULL(sens_value); |
|
|
|
auto new_sens_value = ShallowCopyTensorValue(sens_value); |
|
|
|
MS_EXCEPTION_IF_NULL(new_sens_value); |
|
|
|
MS_LOG(DEBUG) << "sens args [" << sens_value->ToString() << "] has been shallow copied to [" |
|
|
|
<< new_sens_value->ToString() << "]."; |
|
|
|
(*run_args)[i] = new_sens_value; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void GradExecutor::UpdateParamAbsByArgs(const py::list &args, const FuncGraphPtr &bprop_graph) { |
|
|
|
MS_EXCEPTION_IF_NULL(bprop_graph); |
|
|
|
const auto &bprop_params = bprop_graph->parameters(); |
|
|
|
@@ -3061,8 +3088,10 @@ void GradExecutor::RunGradGraph(py::object *ret, const py::object &cell, const p |
|
|
|
MS_LOG(DEBUG) << "Run resource ptr " << resource.get(); |
|
|
|
|
|
|
|
VectorRef arg_list; |
|
|
|
py::tuple converted_args = ConvertArgs(FilterTensorArgs(args, has_sens)); |
|
|
|
auto filter_args = FilterTensorArgs(args, has_sens); |
|
|
|
py::tuple converted_args = ConvertArgs(filter_args); |
|
|
|
pipeline::ProcessVmArgInner(converted_args, resource, &arg_list); |
|
|
|
ShallowCopySensValue(filter_args, has_sens, &arg_list); |
|
|
|
MS_LOG(DEBUG) << "Convert args size " << converted_args.size() << ", graph param size " << arg_list.size(); |
|
|
|
compile::VmEvalFuncPtr run = resource->GetResult(pipeline::kOutput).cast<compile::VmEvalFuncPtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(run); |
|
|
|
|