浏览代码

!1430 Pynative can not add cell hook

Merge pull request !1430 from JoyLvliang/pynative-cell-hook-grad-abnormal
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 年前
父节点
当前提交
8c4f5e5019
共有 3 个文件被更改,包括 6 次插入3 次删除
  1. +3
    -0
      mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc
  2. +1
    -1
      mindspore/nn/cell.py
  3. +2
    -2
      mindspore/ops/operations/debug_ops.py

+ 3
- 0
mindspore/ccsrc/pre_activate/ascend/format_type/merge_cast_to_op.cc 查看文件

@@ -170,6 +170,9 @@ bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_outpu
MS_EXCEPTION_IF_NULL(output_idx);
AnfNodePtr input1 = x_cnode->input(1);
MS_EXCEPTION_IF_NULL(input1);
if (!input1->isa<CNode>()) {
return false;
}
*prior_op = input1->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(*prior_op);
AnfNodePtr input2 = x_cnode->input(2);


+ 1
- 1
mindspore/nn/cell.py 查看文件

@@ -762,5 +762,5 @@ class Cell:
Args:
fn (function): Specifies the hook function with grad as input.
"""
self._backward_hook = HookBackward(fn, str(id(self)))
self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
self._enable_hook = True

+ 2
- 2
mindspore/ops/operations/debug_ops.py 查看文件

@@ -14,7 +14,7 @@
# ============================================================================

"""debug_ops"""
from types import FunctionType
from types import FunctionType, MethodType
from ..._checkparam import Validator as validator
from ...common import dtype as mstype
from ..primitive import prim_attr_register, PrimitiveWithInfer
@@ -279,7 +279,7 @@ class HookBackward(PrimitiveWithInfer):
super(HookBackward, self).__init__(self.__class__.__name__)
self.add_prim_attr("cell_id", cell_id)
self.init_attrs["cell_id"] = cell_id
if not isinstance(hook_fn, FunctionType):
if not isinstance(hook_fn, (FunctionType, MethodType)):
raise TypeError("Hook function should be python function type.")
self.register_hook(hook_fn)
self.cell_id = cell_id


正在加载...
取消
保存