diff --git a/mindspore/ccsrc/vm/vm.cc b/mindspore/ccsrc/vm/vm.cc
index ea55597f15..7107212b6c 100644
--- a/mindspore/ccsrc/vm/vm.cc
+++ b/mindspore/ccsrc/vm/vm.cc
@@ -603,6 +603,19 @@ void FinalVM::InstPushPrim(const VectorRef &args) {
   MS_LOG(DEBUG) << "End";
 }
 
+void FinalVM::SyncData(const py::object &arg) {
+  if (py::isinstance<py::tuple>(arg)) {
+    py::tuple arg_list = py::cast<py::tuple>(arg);
+    for (size_t i = 0; i < arg_list.size(); i++) {
+      SyncData(arg_list[i]);
+    }
+  }
+  if (py::isinstance<tensor::Tensor>(arg)) {
+    auto tensor = py::cast<tensor::TensorPtr>(arg);
+    (void)tensor->data_sync();
+  }
+}
+
 BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
   MS_LOG(DEBUG) << "input for operation:";
   std::size_t args_size = args.size();
@@ -613,15 +626,20 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
     MS_LOG(DEBUG) << "arg: " << i << ":";
     i++;
   }
+  // Hook operator for executing the cell custom bprop function
   py::object obj;
   bool is_bprop = prim->HasAttr("bprop");
   if (is_bprop) {
+    SyncData(py_args);
     py::function fn_bprop = prim->hook();
     obj = fn_bprop(*py_args);
     return obj;
   }
+  // Sync gradient data from device to host
+  SyncData(py_args[2]);
   bool is_cell = prim->HasAttr("cell_hook");
   if (is_cell) {
+    // Hook operator for executing the cell hook function
     std::string cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
     if (_hook_grad.find(cell_id) != _hook_grad.end()) {
       std::size_t hook_args_size = 3;
@@ -640,6 +658,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
       obj = py_args[2];
     }
   } else {
+    // Hook operator for executing the variable hook function
     py::function fn_hook = prim->hook();
     obj = fn_hook(py::make_tuple(py_args[2]));
     if (py::isinstance<py::none>(obj)) {
diff --git a/mindspore/ccsrc/vm/vm.h b/mindspore/ccsrc/vm/vm.h
index 4579b1bc97..6a078c9baf 100644
--- a/mindspore/ccsrc/vm/vm.h
+++ b/mindspore/ccsrc/vm/vm.h
@@ -115,7 +115,7 @@ class FinalVM {
   void InstPushPrim(const VectorRef &args);
   void InstSwitchReturn(const VectorRef &args);
   void set_insts(const InstSet &value) { insts_ = value; }
-  BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &args);
+  BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &arg);
 
 protected:
   BaseRef Ref(int i);
@@ -129,6 +129,7 @@ class FinalVM {
   void PushStatus(bool is_switch_call);
   bool PopStatus();
   void DoJmp(const BaseRef &jmp);
+  void SyncData(const py::object &args);
   void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c);
   BaseRef MergeArgs(const BaseRef &first, const BaseRef &second);
 
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index e6d2dc7383..dd8c4dac27 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -77,7 +77,7 @@ class Cell:
         if flags:
             self.add_flags(**flags)
         self._backward_hook = None
-        self._enable_hook = False
+        self.enable_hook = False
         self._bprop_debug = False
 
     @property
@@ -97,10 +97,24 @@ class Cell:
 
     @property
     def bprop_debug(self):
+        """
+        Get whether cell custom bprop debug is enabled.
+        """
         return self._bprop_debug
 
     @bprop_debug.setter
     def bprop_debug(self, value):
+        """
+        Set whether to enable cell custom bprop debug.
+
+        Note:
+            When bprop is defined in a cell, the bprop function will be executed
+            by the python interpreter when bprop debug is true, and will be parsed
+            and added to the graph when bprop debug is false.
+
+        Args:
+            value (bool): Specifies whether to enable bprop debug. Default: False.
+ """ if not isinstance(value, bool): raise TypeError("'bprop debug' value must be bool type.") self._bprop_debug = value @@ -755,17 +769,19 @@ class Cell: outputs = self._backward_hook(inputs) return outputs - @property - def enable_hook(self): - """Whether the cell register hook function""" - return self._enable_hook - def register_backward_hook(self, fn): """ Set the cell backward hook function. + Note: + fn should be defined as following code shows, `cell_name` is the name of registered cell, + `grad_input` is gradient passed to the cell, `grad_output` is the gradient computed and pass to + next cell or primitve, which may be modified and return. + >>> hook_fn(cell_name, grad_input, grad_output) -> Tensor or None + Args: fn (function): Specifies the hook function with grad as input. + """ self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")") self._enable_hook = True diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index 28d07d9260..f1b56b2850 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -247,9 +247,11 @@ class HookBackward(PrimitiveWithInfer): Used as tag to hook gradient in intermediate variables. Note: - The hook function should have one input of gradient of the variable. - hook function will be executed in python environment, while callback - of InsertGradientOf will be parsed and added to the graph. + The hook function should be defined like `hook_fn(grad) -> Tensor or None`, + which grad is the gradient passed to the primitive and gradient may be + modified and passed to nex primitive. the difference between hook function and + callback of InsertGradientOf is that hook function is executed in python + environment while callback will be parsed and added to the graph. Args: hook_fn (Function): Python function. hook function. @@ -312,6 +314,8 @@ class Print(PrimitiveWithInfer): 2. The data of tensor is a scalar type. + In pynative mode, please use python print function. + Inputs: - **input_x** (Union[Tensor, str]) - The graph node to attach to. The input supports multiple strings and tensors which are separated by ','.