|
|
|
@@ -14,7 +14,7 @@ |
|
|
|
# ============================================================================ |
|
|
|
|
|
|
|
"""debug_ops""" |
|
|
|
from types import FunctionType |
|
|
|
from types import FunctionType, MethodType |
|
|
|
from ..._checkparam import Validator as validator |
|
|
|
from ...common import dtype as mstype |
|
|
|
from ..primitive import prim_attr_register, PrimitiveWithInfer |
|
|
|
@@ -279,7 +279,7 @@ class HookBackward(PrimitiveWithInfer): |
|
|
|
super(HookBackward, self).__init__(self.__class__.__name__) |
|
|
|
self.add_prim_attr("cell_id", cell_id) |
|
|
|
self.init_attrs["cell_id"] = cell_id |
|
|
|
if not isinstance(hook_fn, FunctionType): |
|
|
|
if not isinstance(hook_fn, (FunctionType, MethodType)): |
|
|
|
raise TypeError("Hook function should be python function type.") |
|
|
|
self.register_hook(hook_fn) |
|
|
|
self.cell_id = cell_id |
|
|
|
|