Merge pull request !1687 from wangqiuliang/add-data-sync-before-hook-functiontags/v0.5.0-beta
| @@ -603,6 +603,19 @@ void FinalVM::InstPushPrim(const VectorRef &args) { | |||
| MS_LOG(DEBUG) << "End"; | |||
| } | |||
| void FinalVM::SyncData(const py::object &arg) { | |||
| if (py::isinstance<py::tuple>(arg)) { | |||
| py::tuple arg_list = py::cast<py::tuple>(arg); | |||
| for (size_t i = 0; i < arg_list.size(); i++) { | |||
| SyncData(arg_list[i]); | |||
| } | |||
| } | |||
| if (py::isinstance<tensor::Tensor>(arg)) { | |||
| auto tensor = py::cast<tensor::TensorPtr>(arg); | |||
| (void)tensor->data_sync(); | |||
| } | |||
| } | |||
| BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) { | |||
| MS_LOG(DEBUG) << "input for operation:"; | |||
| std::size_t args_size = args.size(); | |||
| @@ -613,15 +626,20 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) { | |||
| MS_LOG(DEBUG) << "arg: " << i << ":"; | |||
| i++; | |||
| } | |||
| // Hook operator for execute cell custom bprop function | |||
| py::object obj; | |||
| bool is_bprop = prim->HasAttr("bprop"); | |||
| if (is_bprop) { | |||
| SyncData(py_args); | |||
| py::function fn_bprop = prim->hook(); | |||
| obj = fn_bprop(*py_args); | |||
| return obj; | |||
| } | |||
| // Sync gradient data from device to host | |||
| SyncData(py_args[2]); | |||
| bool is_cell = prim->HasAttr("cell_hook"); | |||
| if (is_cell) { | |||
| // Hook operator for execute cell hook function | |||
| std::string cell_id = GetValue<std::string>(prim->GetAttr("cell_id")); | |||
| if (_hook_grad.find(cell_id) != _hook_grad.end()) { | |||
| std::size_t hook_args_size = 3; | |||
| @@ -640,6 +658,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) { | |||
| obj = py_args[2]; | |||
| } | |||
| } else { | |||
| // Hook operator for execute variable hook function | |||
| py::function fn_hook = prim->hook(); | |||
| obj = fn_hook(py::make_tuple(py_args[2])); | |||
| if (py::isinstance<py::none>(obj)) { | |||
| @@ -115,7 +115,7 @@ class FinalVM { | |||
| void InstPushPrim(const VectorRef &args); | |||
| void InstSwitchReturn(const VectorRef &args); | |||
| void set_insts(const InstSet &value) { insts_ = value; } | |||
| BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &args); | |||
| BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &arg); | |||
| protected: | |||
| BaseRef Ref(int i); | |||
| @@ -129,6 +129,7 @@ class FinalVM { | |||
| void PushStatus(bool is_switch_call); | |||
| bool PopStatus(); | |||
| void DoJmp(const BaseRef &jmp); | |||
| void SyncData(const py::object &args); | |||
| void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c); | |||
| BaseRef MergeArgs(const BaseRef &first, const BaseRef &second); | |||
| @@ -77,7 +77,7 @@ class Cell: | |||
| if flags: | |||
| self.add_flags(**flags) | |||
| self._backward_hook = None | |||
| self._enable_hook = False | |||
| self.enable_hook = False | |||
| self._bprop_debug = False | |||
| @property | |||
| @@ -97,10 +97,24 @@ class Cell: | |||
| @property | |||
| def bprop_debug(self): | |||
| """ | |||
| Get whether cell custom bprop debug is enabled. | |||
| """ | |||
| return self._bprop_debug | |||
| @bprop_debug.setter | |||
| def bprop_debug(self, value): | |||
| """ | |||
| Set whether to enable cell custom bprop debug. | |||
| Note: | |||
| When bprop is defined in cell, the bprop function will be executed | |||
| in python interpreter when bprop debug is true, and will be parsed | |||
| and add to graph when bprop debug is false. | |||
| Args: | |||
| value (bool): Specifies whether to enable bprop debug. Default: False. | |||
| """ | |||
| if not isinstance(value, bool): | |||
| raise TypeError("'bprop debug' value must be bool type.") | |||
| self._bprop_debug = value | |||
| @@ -755,17 +769,19 @@ class Cell: | |||
| outputs = self._backward_hook(inputs) | |||
| return outputs | |||
| @property | |||
| def enable_hook(self): | |||
| """Whether the cell register hook function""" | |||
| return self._enable_hook | |||
| def register_backward_hook(self, fn): | |||
| """ | |||
| Set the cell backward hook function. | |||
| Note: | |||
| fn should be defined as following code shows, `cell_name` is the name of registered cell, | |||
| `grad_input` is gradient passed to the cell, `grad_output` is the gradient computed and pass to | |||
| next cell or primitve, which may be modified and return. | |||
| >>> hook_fn(cell_name, grad_input, grad_output) -> Tensor or None | |||
| Args: | |||
| fn (function): Specifies the hook function with grad as input. | |||
| """ | |||
| self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")") | |||
| self._enable_hook = True | |||
| @@ -247,9 +247,11 @@ class HookBackward(PrimitiveWithInfer): | |||
| Used as tag to hook gradient in intermediate variables. | |||
| Note: | |||
| The hook function should have one input of gradient of the variable. | |||
| hook function will be executed in python environment, while callback | |||
| of InsertGradientOf will be parsed and added to the graph. | |||
| The hook function should be defined like `hook_fn(grad) -> Tensor or None`, | |||
| which grad is the gradient passed to the primitive and gradient may be | |||
| modified and passed to nex primitive. the difference between hook function and | |||
| callback of InsertGradientOf is that hook function is executed in python | |||
| environment while callback will be parsed and added to the graph. | |||
| Args: | |||
| hook_fn (Function): Python function. hook function. | |||
| @@ -312,6 +314,8 @@ class Print(PrimitiveWithInfer): | |||
| 2. The data of tensor is a scalar type. | |||
| In pynative mode, please use python print function. | |||
| Inputs: | |||
| - **input_x** (Union[Tensor, str]) - The graph node to attach to. The input supports | |||
| multiple strings and tensors which are separated by ','. | |||