|
|
|
@@ -182,11 +182,12 @@ void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::obj |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
BaseRef PrimitivePy::RunBpropHookFunction(const py::tuple &py_args) const { |
|
|
|
BaseRef PrimitivePy::RunCellBpropFunction(const py::tuple &py_args) const { |
|
|
|
SyncData(py_args); |
|
|
|
auto size = py_args.size(); |
|
|
|
py::tuple input_args(size - 2); |
|
|
|
for (size_t i = 0; i < size - 2; ++i) { |
|
|
|
constexpr size_t grad_param_nums = 2; |
|
|
|
py::tuple input_args(size - grad_param_nums); |
|
|
|
for (size_t i = 0; i < size - grad_param_nums; ++i) { |
|
|
|
input_args[i] = py_args[i]; |
|
|
|
} |
|
|
|
py::tuple convert_args(py_args.size()); |
|
|
|
@@ -207,48 +208,62 @@ BaseRef PrimitivePy::RunBpropHookFunction(const py::tuple &py_args) const { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const { |
|
|
|
constexpr size_t grad_input_index = 1; |
|
|
|
constexpr size_t grad_output_index = 2; |
|
|
|
constexpr size_t input_param_nums = 3; |
|
|
|
SyncData(py_args[grad_output_index]); |
|
|
|
|
|
|
|
py::object obj; |
|
|
|
auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName)); |
|
|
|
auto iter = hook_grad_.find(cell_id); |
|
|
|
if (iter != hook_grad_.end()) { |
|
|
|
py::tuple convert_args(input_param_nums - 1); |
|
|
|
py::tuple input_args(input_param_nums - 1); |
|
|
|
input_args[0] = iter->second; |
|
|
|
input_args[1] = py_args[grad_output_index]; |
|
|
|
ConvertCTensorToPyTensor(input_args, &convert_args); |
|
|
|
auto hook_args = py::tuple(input_param_nums); |
|
|
|
hook_args[0] = cell_id; |
|
|
|
hook_args[grad_input_index] = py::make_tuple(convert_args[0]); |
|
|
|
hook_args[grad_output_index] = py::make_tuple(convert_args[1]); |
|
|
|
obj = hook_(*hook_args); |
|
|
|
if (py::isinstance<py::none>(obj)) { |
|
|
|
obj = py_args[grad_output_index]; |
|
|
|
} |
|
|
|
CheckHookConsistency(obj, py_args[grad_output_index]); |
|
|
|
(void)hook_grad_.erase(cell_id); |
|
|
|
} else { |
|
|
|
hook_grad_[cell_id] = py_args[grad_output_index]; |
|
|
|
obj = py_args[grad_output_index]; |
|
|
|
} |
|
|
|
obj = py::make_tuple(obj); |
|
|
|
return std::make_shared<PyObjectRef>(obj); |
|
|
|
} |
|
|
|
|
|
|
|
BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args) const { |
|
|
|
constexpr size_t grad_output_index = 2; |
|
|
|
SyncData(py_args[grad_output_index]); |
|
|
|
py::object obj = hook_(py::make_tuple(py_args[grad_output_index])); |
|
|
|
if (py::isinstance<py::none>(obj)) { |
|
|
|
obj = py_args[grad_output_index]; |
|
|
|
} |
|
|
|
CheckHookConsistency(obj, py_args[grad_output_index]); |
|
|
|
obj = py::make_tuple(obj); |
|
|
|
return std::make_shared<PyObjectRef>(obj); |
|
|
|
} |
|
|
|
|
|
|
|
BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { |
|
|
|
py::tuple py_args = ConvertDatatoPyTuple(args); |
|
|
|
bool is_bprop = this->HasAttr(kBpropAttrName); |
|
|
|
if (is_bprop) { |
|
|
|
return RunBpropHookFunction(py_args); |
|
|
|
return RunCellBpropFunction(py_args); |
|
|
|
} |
|
|
|
SyncData(py_args[2]); |
|
|
|
bool is_cell = this->HasAttr(kCellHookAttrName); |
|
|
|
py::object obj; |
|
|
|
if (is_cell) { |
|
|
|
auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName)); |
|
|
|
auto iter = hook_grad_.find(cell_id); |
|
|
|
if (iter != hook_grad_.end()) { |
|
|
|
py::tuple convert_args(2); |
|
|
|
py::tuple input_args(2); |
|
|
|
input_args[0] = iter->second; |
|
|
|
input_args[1] = py_args[2]; |
|
|
|
ConvertCTensorToPyTensor(input_args, &convert_args); |
|
|
|
auto hook_args = py::tuple(3); |
|
|
|
hook_args[0] = cell_id; |
|
|
|
hook_args[1] = py::make_tuple(convert_args[0]); |
|
|
|
hook_args[2] = py::make_tuple(convert_args[1]); |
|
|
|
obj = hook_(*hook_args); |
|
|
|
if (py::isinstance<py::none>(obj)) { |
|
|
|
obj = py_args[2]; |
|
|
|
} |
|
|
|
CheckHookConsistency(obj, py_args[2]); |
|
|
|
(void)hook_grad_.erase(cell_id); |
|
|
|
} else { |
|
|
|
hook_grad_[cell_id] = py_args[2]; |
|
|
|
obj = py_args[2]; |
|
|
|
} |
|
|
|
} else { |
|
|
|
// Hook operator for execute variable hook function |
|
|
|
obj = hook_(py::make_tuple(py_args[2])); |
|
|
|
if (py::isinstance<py::none>(obj)) { |
|
|
|
obj = py_args[2]; |
|
|
|
} |
|
|
|
CheckHookConsistency(obj, py_args[2]); |
|
|
|
return RunCellHookFunction(py_args); |
|
|
|
} |
|
|
|
obj = py::make_tuple(obj); |
|
|
|
return std::make_shared<PyObjectRef>(obj); |
|
|
|
return RunVariableHookFunction(py_args); |
|
|
|
} |
|
|
|
|
|
|
|
py::function PrimitivePy::GetComputeFunction() const { |
|
|
|
|