|
|
|
@@ -753,7 +753,7 @@ void ForwardExecutor::RunOpInner(py::object *ret, const OpExecInfoPtr &op_exec_i |
|
|
|
} |
|
|
|
|
|
|
|
// Save cnode info and build grad graph |
|
|
|
if (grad()->need_construct_graph()) { |
|
|
|
if (grad()->need_construct_graph() && !grad()->in_cell_with_custom_bprop_()) { |
|
|
|
grad()->SaveOutputNodeMap(obj_id, out_real, cnode); |
|
|
|
grad()->DoOpGrad(op_exec_info, cnode, out_real); |
|
|
|
} |
|
|
|
@@ -1778,6 +1778,10 @@ void GradExecutor::InitResourceAndDfBuilder(const std::string &cell_id, const py |
|
|
|
void GradExecutor::NewGraphInner(py::object *ret, const py::object &cell, const py::args &args) { |
|
|
|
auto cell_id = GetCellId(cell, args); |
|
|
|
MS_LOG(DEBUG) << "NewGraphInner start " << args.size() << " " << cell_id; |
|
|
|
// When the cell has custom bprop, custom_bprop_cell_count_ is larger than 0 |
|
|
|
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { |
|
|
|
custom_bprop_cell_count_ += 1; |
|
|
|
} |
|
|
|
if (cell_stack_.empty() && top_cell_ != nullptr) { |
|
|
|
// non-first step |
|
|
|
if (!top_cell()->IsSubCell(cell_id) && already_run_top_cell_.find(cell_id) != already_run_top_cell_.end()) { |
|
|
|
@@ -1910,6 +1914,7 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const |
|
|
|
MS_LOG(DEBUG) << "Brop cell no need construct graph"; |
|
|
|
return; |
|
|
|
} |
|
|
|
DoGradForCustomBprop(cell, out, args); |
|
|
|
if ((cell_stack_.size() > 1 && !IsNestedGrad()) || (IsNestedGrad() && cell_stack_.size() != cell_nums())) { |
|
|
|
PopCellStack(); |
|
|
|
MS_LOG(DEBUG) << "Sub cell no need construct graph"; |
|
|
|
@@ -1927,6 +1932,46 @@ void GradExecutor::EndGraphInner(py::object *ret, const py::object &cell, const |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// Records a single grad node for a cell that defines a custom bprop function.
//
// Does nothing when `cell` has no custom bprop attribute. Otherwise decrements
// custom_bprop_cell_count_ and proceeds only when the counter reaches zero. In that case a
// fake hook-backward primitive wrapping the cell's bprop is created, a cnode is built for it
// from the leading cell inputs, and that cnode is registered as the producer of `out`.
//
// cell: Python cell object that may carry the custom bprop attribute.
// out:  forward output of the cell; passed to DoOpGrad and SaveOutputNodeMap.
// args: positional inputs the cell was called with.
//
// Raises through MS_LOG(EXCEPTION) when the bprop function declares fewer than the three
// mandatory parameters (self, out, dout), or more inputs than the cell received.
void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object &out, const py::args &args) {
  if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
    return;
  }
  // Only the call that brings the counter back to zero builds the grad node.
  custom_bprop_cell_count_ -= 1;
  if (custom_bprop_cell_count_ != 0) {
    return;
  }

  // Wrap the Python bprop in a fake hook-backward primitive tagged with the cell id.
  py::function bprop_func = py::getattr(cell, parse::CUSTOM_BPROP_NAME);
  auto fake_prim = std::make_shared<PrimitivePy>(prim::kPrimHookBackward->name(), py::object());
  fake_prim->set_hook(bprop_func);
  const auto &cell_id = GetCellId(cell, args);
  (void)fake_prim->AddAttr("cell_id", MakeValue(cell_id));
  (void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));

  py::object code_obj = py::getattr(bprop_func, "__code__");
  // Three parameters self, out and dout need to be excluded
  constexpr int64_t kBpropExcludedParamNum = 3;
  const auto bprop_arg_count = py::cast<int64_t>(py::getattr(code_obj, "co_argcount"));
  // Validate while the value is still signed: subtracting first and storing into size_t would
  // wrap around for a bprop with fewer than three parameters.
  if (bprop_arg_count < kBpropExcludedParamNum) {
    MS_LOG(EXCEPTION) << "Bprop func should have at least " << kBpropExcludedParamNum
                      << " parameters (self, out and dout), but got " << bprop_arg_count;
  }
  const auto inputs_num = static_cast<size_t>(bprop_arg_count - kBpropExcludedParamNum);
  if (inputs_num > args.size()) {
    MS_LOG(EXCEPTION) << "Size of bprop func inputs[" << inputs_num << "] is larger than size of cell inputs["
                      << args.size() << "]";
  }

  // Forward only as many cell inputs as the bprop signature declares.
  py::list cell_inputs;
  for (size_t i = 0; i < inputs_num; ++i) {
    cell_inputs.append(args[i]);
  }
  OpExecInfoPtr op_exec_info = std::make_shared<OpExecInfo>();
  op_exec_info->op_name = fake_prim->name();
  op_exec_info->py_primitive = fake_prim;
  op_exec_info->op_inputs = cell_inputs;

  abstract::AbstractBasePtrList args_spec_list;
  std::vector<int64_t> op_masks;
  auto cnode = forward()->MakeCNode(op_exec_info, &op_masks, &args_spec_list);
  DoOpGrad(op_exec_info, cnode, out);
  const std::string out_obj_id = GetId(out);
  SaveOutputNodeMap(out_obj_id, out, cnode);
}
|
|
|
|
|
|
|
void GradExecutor::UpdateBpropCellGraph(const py::object &cell, const std::string &cell_id) { |
|
|
|
if (!py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) { |
|
|
|
return; |
|
|
|
|