|
|
@@ -624,8 +624,8 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) { |
|
|
if (_hook_grad.find(cell_id) != _hook_grad.end()) { |
|
|
if (_hook_grad.find(cell_id) != _hook_grad.end()) { |
|
|
py::tuple hook_args = py::tuple(3); |
|
|
py::tuple hook_args = py::tuple(3); |
|
|
hook_args[0] = cell_id; |
|
|
hook_args[0] = cell_id; |
|
|
hook_args[1] = _hook_grad[cell_id]; |
|
|
|
|
|
hook_args[2] = py_args[2]; |
|
|
|
|
|
|
|
|
hook_args[1] = py::make_tuple(_hook_grad[cell_id]); |
|
|
|
|
|
hook_args[2] = py::make_tuple(py_args[2]); |
|
|
py::function fn_hook = prim->hook(); |
|
|
py::function fn_hook = prim->hook(); |
|
|
obj = fn_hook(*hook_args); |
|
|
obj = fn_hook(*hook_args); |
|
|
if (py::isinstance<py::none>(obj)) { |
|
|
if (py::isinstance<py::none>(obj)) { |
|
|
@@ -638,7 +638,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) { |
|
|
} |
|
|
} |
|
|
} else { |
|
|
} else { |
|
|
py::function fn_hook = prim->hook(); |
|
|
py::function fn_hook = prim->hook(); |
|
|
obj = fn_hook(py_args[2]); |
|
|
|
|
|
|
|
|
obj = fn_hook(py::make_tuple(py_args[2])); |
|
|
if (py::isinstance<py::none>(obj)) { |
|
|
if (py::isinstance<py::none>(obj)) { |
|
|
obj = py_args[2]; |
|
|
obj = py_args[2]; |
|
|
} |
|
|
} |
|
|
|