Merge pull request !1687 from wangqiuliang/add-data-sync-before-hook-function (tags/v0.5.0-beta)
@@ -603,6 +603,19 @@ void FinalVM::InstPushPrim(const VectorRef &args) {
   MS_LOG(DEBUG) << "End";
 }
 
+void FinalVM::SyncData(const py::object &arg) {
+  if (py::isinstance<py::tuple>(arg)) {
+    py::tuple arg_list = py::cast<py::tuple>(arg);
+    for (size_t i = 0; i < arg_list.size(); i++) {
+      SyncData(arg_list[i]);
+    }
+  }
+  if (py::isinstance<tensor::Tensor>(arg)) {
+    auto tensor = py::cast<tensor::TensorPtr>(arg);
+    (void)tensor->data_sync();
+  }
+}
+
 BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
   MS_LOG(DEBUG) << "input for operation:";
   std::size_t args_size = args.size();
@@ -613,15 +626,20 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
     MS_LOG(DEBUG) << "arg: " << i << ":";
     i++;
   }
+  // Hook operator for executing a cell's custom bprop function
   py::object obj;
   bool is_bprop = prim->HasAttr("bprop");
   if (is_bprop) {
+    SyncData(py_args);
     py::function fn_bprop = prim->hook();
     obj = fn_bprop(*py_args);
     return obj;
   }
+  // Sync gradient data from device to host
+  SyncData(py_args[2]);
   bool is_cell = prim->HasAttr("cell_hook");
   if (is_cell) {
+    // Hook operator for executing a cell hook function
     std::string cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
     if (_hook_grad.find(cell_id) != _hook_grad.end()) {
       std::size_t hook_args_size = 3;
@@ -640,6 +658,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
       obj = py_args[2];
     }
   } else {
+    // Hook operator for executing a variable hook function
     py::function fn_hook = prim->hook();
     obj = fn_hook(py::make_tuple(py_args[2]));
     if (py::isinstance<py::none>(obj)) {
| @@ -115,7 +115,7 @@ class FinalVM { | |||||
| void InstPushPrim(const VectorRef &args); | void InstPushPrim(const VectorRef &args); | ||||
| void InstSwitchReturn(const VectorRef &args); | void InstSwitchReturn(const VectorRef &args); | ||||
| void set_insts(const InstSet &value) { insts_ = value; } | void set_insts(const InstSet &value) { insts_ = value; } | ||||
| BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &args); | |||||
| BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &arg); | |||||
| protected: | protected: | ||||
| BaseRef Ref(int i); | BaseRef Ref(int i); | ||||
| @@ -129,6 +129,7 @@ class FinalVM { | |||||
| void PushStatus(bool is_switch_call); | void PushStatus(bool is_switch_call); | ||||
| bool PopStatus(); | bool PopStatus(); | ||||
| void DoJmp(const BaseRef &jmp); | void DoJmp(const BaseRef &jmp); | ||||
| void SyncData(const py::object &args); | |||||
| void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c); | void MergeJmpArgs(const BaseRef &jmp, const BaseRef &c); | ||||
| BaseRef MergeArgs(const BaseRef &first, const BaseRef &second); | BaseRef MergeArgs(const BaseRef &first, const BaseRef &second); | ||||
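The user-visible effect of the C++ changes above is that a backward hook can now safely read gradient values on the host before they have been explicitly fetched. A minimal PyNative-mode sketch follows; it is not part of this pull request, the Net and hook_fn names are illustrative, and GradOperation's constructor signature has varied across MindSpore releases, so treat that call as an assumption.

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE)

def hook_fn(cell_name, grad_input, grad_output):
    # Printing (or calling .asnumpy() on) the gradient reads tensor data on the host;
    # the SyncData call added to FinalVM::RunHook copies it off the device beforehand.
    print("backward hook for", cell_name, ":", grad_output)
    return None  # None leaves the gradient unchanged

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.dense = nn.Dense(3, 2)
        self.dense.register_backward_hook(hook_fn)

    def construct(self, x):
        return self.dense(x)

net = Net()
x = Tensor(np.ones([1, 3]).astype(np.float32))
grad_all = C.GradOperation(get_all=True)  # signature assumption; older releases also took a name argument
grads = grad_all(net)(x)                  # the backward pass triggers hook_fn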
| @@ -77,7 +77,7 @@ class Cell: | |||||
| if flags: | if flags: | ||||
| self.add_flags(**flags) | self.add_flags(**flags) | ||||
| self._backward_hook = None | self._backward_hook = None | ||||
| self._enable_hook = False | |||||
| self.enable_hook = False | |||||
| self._bprop_debug = False | self._bprop_debug = False | ||||
| @property | @property | ||||
| @@ -97,10 +97,24 @@ class Cell: | |||||
| @property | @property | ||||
| def bprop_debug(self): | def bprop_debug(self): | ||||
| """ | |||||
| Get whether cell custom bprop debug is enabled. | |||||
| """ | |||||
| return self._bprop_debug | return self._bprop_debug | ||||
| @bprop_debug.setter | @bprop_debug.setter | ||||
| def bprop_debug(self, value): | def bprop_debug(self, value): | ||||
| """ | |||||
| Set whether to enable cell custom bprop debug. | |||||
| Note: | |||||
| When bprop is defined in cell, the bprop function will be executed | |||||
| in python interpreter when bprop debug is true, and will be parsed | |||||
| and add to graph when bprop debug is false. | |||||
| Args: | |||||
| value (bool): Specifies whether to enable bprop debug. Default: False. | |||||
| """ | |||||
| if not isinstance(value, bool): | if not isinstance(value, bool): | ||||
| raise TypeError("'bprop debug' value must be bool type.") | raise TypeError("'bprop debug' value must be bool type.") | ||||
| self._bprop_debug = value | self._bprop_debug = value | ||||
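For reference, a minimal sketch of what the bprop_debug property documented above controls, assuming a cell that defines a custom bprop; MulCell and its backward rule are illustrative, not taken from this pull request.

import mindspore.nn as nn
import mindspore.ops.operations as P

class MulCell(nn.Cell):
    # A cell with a hand-written backward rule (custom bprop).
    def __init__(self):
        super(MulCell, self).__init__()
        self.mul = P.Mul()

    def construct(self, x, y):
        return self.mul(x, y)

    def bprop(self, x, y, out, dout):
        # d(x*y)/dx = y, d(x*y)/dy = x; return one gradient per construct input
        return dout * y, dout * x

net = MulCell()
net.bprop_debug = True   # run the custom bprop in the Python interpreter (debuggable)
# net.bprop_debug = False would let it be parsed and added to the graph instead.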
| @@ -755,17 +769,19 @@ class Cell: | |||||
| outputs = self._backward_hook(inputs) | outputs = self._backward_hook(inputs) | ||||
| return outputs | return outputs | ||||
| @property | |||||
| def enable_hook(self): | |||||
| """Whether the cell register hook function""" | |||||
| return self._enable_hook | |||||
| def register_backward_hook(self, fn): | def register_backward_hook(self, fn): | ||||
| """ | """ | ||||
| Set the cell backward hook function. | Set the cell backward hook function. | ||||
| Note: | |||||
| fn should be defined as following code shows, `cell_name` is the name of registered cell, | |||||
| `grad_input` is gradient passed to the cell, `grad_output` is the gradient computed and pass to | |||||
| next cell or primitve, which may be modified and return. | |||||
| >>> hook_fn(cell_name, grad_input, grad_output) -> Tensor or None | |||||
| Args: | Args: | ||||
| fn (function): Specifies the hook function with grad as input. | fn (function): Specifies the hook function with grad as input. | ||||
| """ | """ | ||||
| self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")") | self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")") | ||||
| self._enable_hook = True | self._enable_hook = True | ||||
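A short sketch of the hook signature spelled out in the Note above, this time showing the modify-and-return case; the scaling factor and the assumption that grad_output arrives as a tuple of tensors are illustrative, not guaranteed by this pull request.

def scale_grad_hook(cell_name, grad_input, grad_output):
    # Returning a Tensor replaces the gradient passed on to the next cell or primitive;
    # returning None would leave it unchanged.
    print("rescaling gradient leaving", cell_name)
    return grad_output[0] * 0.1  # assumes grad_output is a tuple of tensors

# net.dense.register_backward_hook(scale_grad_hook)  # hypothetical cell attribute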
| @@ -247,9 +247,11 @@ class HookBackward(PrimitiveWithInfer): | |||||
| Used as tag to hook gradient in intermediate variables. | Used as tag to hook gradient in intermediate variables. | ||||
| Note: | Note: | ||||
| The hook function should have one input of gradient of the variable. | |||||
| hook function will be executed in python environment, while callback | |||||
| of InsertGradientOf will be parsed and added to the graph. | |||||
| The hook function should be defined like `hook_fn(grad) -> Tensor or None`, | |||||
| which grad is the gradient passed to the primitive and gradient may be | |||||
| modified and passed to nex primitive. the difference between hook function and | |||||
| callback of InsertGradientOf is that hook function is executed in python | |||||
| environment while callback will be parsed and added to the graph. | |||||
| Args: | Args: | ||||
| hook_fn (Function): Python function. hook function. | hook_fn (Function): Python function. hook function. | ||||
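As a companion to the rewritten Note, a minimal sketch of HookBackward usage in PyNative mode; the grad_hook body is illustrative.

from mindspore.ops import operations as P

def grad_hook(grad):
    # Runs in the Python environment during the backward pass.
    print("intermediate gradient:", grad)
    return None  # a returned Tensor would replace the gradient

hook = P.HookBackward(grad_hook)
# Inside a Cell's construct, wrap the intermediate value to be watched:
#     z = self.hook(x * y)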
| @@ -312,6 +314,8 @@ class Print(PrimitiveWithInfer): | |||||
| 2. The data of tensor is a scalar type. | 2. The data of tensor is a scalar type. | ||||
| In pynative mode, please use python print function. | |||||
| Inputs: | Inputs: | ||||
| - **input_x** (Union[Tensor, str]) - The graph node to attach to. The input supports | - **input_x** (Union[Tensor, str]) - The graph node to attach to. The input supports | ||||
| multiple strings and tensors which are separated by ','. | multiple strings and tensors which are separated by ','. | ||||
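A short sketch contrasting the two cases in the added note: Print as a graph-mode operator versus the plain Python print in PyNative mode; PrintNet is an illustrative name.

import mindspore.nn as nn
from mindspore.ops import operations as P

class PrintNet(nn.Cell):
    def __init__(self):
        super(PrintNet, self).__init__()
        self.print = P.Print()

    def construct(self, x):
        self.print("x is:", x)  # graph mode: the Print op is attached to the graph
        return x

# In PyNative mode, a plain print(x) inside construct does the same job.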