|
|
|
@@ -2857,7 +2857,7 @@ py::object GradExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, con |
|
|
|
bool forward_run = false; |
|
|
|
// Get cell id and input args info |
|
|
|
const auto &cell_id = GetCellId(cell, args); |
|
|
|
grad_operation_ = std::to_string(grad->get_all_) + std::to_string(grad->get_by_list_); |
|
|
|
grad_operation_ = std::to_string(grad->get_all_) + std::to_string(grad->get_by_list_) + grad->grad_position_; |
|
|
|
|
|
|
|
std::string input_args_id; |
|
|
|
for (size_t i = 0; i < args.size(); ++i) { |
|
|
|
@@ -3282,6 +3282,10 @@ py::object PynativeExecutor::CheckGraph(const py::object &cell, const py::args & |
|
|
|
return grad_executor()->CheckGraph(cell, args); |
|
|
|
} |
|
|
|
|
|
|
|
void PynativeExecutor::set_grad_position(const prim::GradOperationPtr &grad, const py::object &grad_position) { |
|
|
|
grad->set_grad_position(std::string(py::str(grad_position))); |
|
|
|
} |
|
|
|
|
|
|
|
py::object PynativeExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &cell, |
|
|
|
const py::args &args) { |
|
|
|
return grad_executor()->CheckAlreadyRun(grad, cell, args); |
|
|
|
@@ -3433,6 +3437,7 @@ REGISTER_PYBIND_DEFINE(PynativeExecutor_, ([](const py::module *m) { |
|
|
|
.def("__call__", &PynativeExecutor::Run, "pynative executor run grad graph.") |
|
|
|
.def("set_graph_phase", &PynativeExecutor::set_graph_phase, "pynative set graph phase") |
|
|
|
.def("grad_flag", &PynativeExecutor::grad_flag, "pynative grad flag") |
|
|
|
.def("set_grad_position", &PynativeExecutor::set_grad_position, "set pynative grad position") |
|
|
|
.def("set_grad_flag", &PynativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false), |
|
|
|
"Executor set grad flag.") |
|
|
|
.def("set_py_exe_path", &PynativeExecutor::set_py_exe_path, |
|
|
|
|