Merge pull request !23099 from JoyLvliang/correct_wrong_info_when_using_ms_function_with_bprop
@@ -27,7 +27,6 @@ from textwrap import dedent
import asttokens
from mindspore import Tensor
from mindspore import context
from mindspore import log as logger
from mindspore import nn
from mindspore import ops
@@ -105,8 +104,6 @@ def get_parse_method_of_class(obj, parse_method=None):
        method_name = parse_method
    elif isinstance(obj, nn.Cell):
        if obj.enable_hook:
            if context.get_context("mode") == context.GRAPH_MODE:
                raise ValueError("The graph mode does not support hook function.")
            method_name = "_hook_construct"
        else:
            method_name = "construct"
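With that check gone, enabling a hook no longer raises in graph mode; the hook node is instead eliminated during compilation with a warning (see the DFunctor and SpecialOpEliminater changes below). A minimal sketch of the user-visible effect, assuming the 1.x cell-hook API used in the new tests:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import context, Tensor

    def cell_hook(cell_id, grad_input, grad_output):
        print("cell id:", cell_id)

    context.set_context(mode=context.GRAPH_MODE)
    net = nn.ReLU()
    net.register_backward_hook(cell_hook)
    # Before this change: ValueError("The graph mode does not support hook function.")
    # After: the network compiles with a WARNING and the hook never fires.
    out = net(Tensor(np.ones((2, 2)).astype(np.float32)))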
@@ -210,6 +210,11 @@ void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app,
  for (size_t i = 0; i < cnode_morph->size(); i++) {
    auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToLong(i))});
    auto input = cnode_morph->input(i);
    // Skip HookBackward op
    if (IsPrimitiveCNode(input, prim::kPrimHookBackward)) {
      auto inp_i = input->cast<CNodePtr>();
      input = inp_i->input(1);
    }
    // Backprop sens wrt fvs.
    if (IsValueNode<FuncGraph>(input)) {
      auto func_graph = GetValueNode<FuncGraphPtr>(input);
@@ -257,6 +262,13 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
  std::vector<AdjointPtr> param_adjoints;
  for (size_t i = 0; i < cnode_morph->size(); i++) {
    auto node = cnode_morph->input(i);
    // Skip HookBackward op
    if (IsPrimitiveCNode(node, prim::kPrimHookBackward)) {
      auto input_i = node->cast<CNodePtr>();
      MS_LOG(WARNING)
        << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
      node = input_i->input(1);
    }
    AdjointPtr node_adjoint = nullptr;
    auto node_adjoint_iter = anfnode_to_adjoin_.find(node);
    if (node_adjoint_iter != anfnode_to_adjoin_.end()) {
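Both BackPropagate and MapMorphism now look through HookBackward nodes by rewiring edges to input(1). This is safe because HookBackward is an identity in the forward direction; a quick sketch of the invariant being relied on:

    import numpy as np
    import mindspore.ops.operations as P
    from mindspore import context, Tensor

    context.set_context(mode=context.PYNATIVE_MODE)
    hook = P.HookBackward(lambda grad: print("grad:", grad))
    x = Tensor(np.ones((2, 2)).astype(np.float32))
    # Forward identity: hook(x) carries the same value as x, so the autodiff
    # above can reconnect gradients to the hook's first input without loss.
    assert np.allclose(hook(x).asnumpy(), x.asnumpy())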
@@ -417,11 +429,19 @@ void DFunctor::MapMorphism() {
  // Handle free morphism before output, because in some cases free morphism might depend on output's fv tangent.
  MapFreeMorphism();
  // Skip HookBackward when it is the output node.
  auto output_node = primal_graph_->output();
  if (IsPrimitiveCNode(output_node, prim::kPrimHookBackward)) {
    auto output_cnode = output_node->cast<CNodePtr>();
    MS_LOG(WARNING)
      << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
    output_node = output_cnode->input(1);
  }
  // Handle morphism from output.
  (void)MapMorphism(primal_graph_->output());
  (void)MapMorphism(output_node);
  // Construct K for primal_graph_
  auto output_adjoint = anfnode_to_adjoin_.find(primal_graph_->output());
  // Construct K for primal_graph_.
  auto output_adjoint = anfnode_to_adjoin_.find(output_node);
  // Attach dout_ parameter to output_adjoint.
  output_adjoint->second->AccumulateDout(dout_);
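The extra branch handles the case where the hook is itself the graph output, which the per-input rewrites above never visit. A hypothetical net that would exercise this path:

    import mindspore.nn as nn
    import mindspore.ops.operations as P

    # Hypothetical example: HookBackward produces the graph output, so
    # MapMorphism must be started from its input rather than the hook node.
    class HookAtOutput(nn.Cell):
        def __init__(self):
            super(HookAtOutput, self).__init__()
            self.hook = P.HookBackward(lambda grad: print(grad))

        def construct(self, x):
            return self.hook(x + x)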
@@ -612,7 +632,9 @@ void DFunctor::MapValueObject() {
    AdjointPtr adjoint = nullptr;
    if (IsValueNode<Primitive>(node)) {  // Primitive.
      if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn) {
      auto prim = GetValueNode<PrimitivePtr>(node);
      if (GetValueNode<PrimitivePtr>(node) == prim::kPrimReturn ||
          (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name())) {
        continue;
      }
      MS_LOG(DEBUG) << "Map Primitive node " << node->DebugString() << ".";
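kPrimReturn can be matched directly because a single shared instance is used, but every P.HookBackward(fn) wraps a different Python callable and arrives as a distinct Primitive object, hence the comparison by Hash() and name(). An illustrative sketch of that assumption:

    import mindspore.ops.operations as P

    h1 = P.HookBackward(lambda grad: None)
    h2 = P.HookBackward(lambda grad: None)
    # Two hook primitives are distinct objects even with equivalent hooks,
    # so they cannot be matched by identity the way kPrimReturn is.
    assert h1 is not h2
    assert h1.name == h2.name  # both "HookBackward"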
@@ -63,6 +63,10 @@ class SpecialOpEliminater : public OptimizerCaller {
    for (auto &eliminater : eliminaters_) {
      new_node = (*eliminater)(optimizer, node);
      if (new_node != nullptr) {
        if (IsPrimitiveCNode(node, prim::kPrimHookBackward)) {
          MS_LOG(WARNING)
            << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
        }
        return new_node;
      }
    }
@@ -2374,6 +2374,10 @@ void GradExecutor::DoGradForCustomBprop(const py::object &cell, const py::object
  (void)fake_prim->AddAttr(parse::CUSTOM_BPROP_NAME, MakeValue(true));
  py::object code_obj = py::getattr(bprop_func, "__code__");
  py::object co_name = py::getattr(code_obj, "co_name");
  if (std::string(py::str(co_name)) == "staging_specialize") {
    MS_LOG(EXCEPTION) << "Decorating bprop with '@ms_function' is not supported.";
  }
  // Three parameters self, out and dout need to be excluded.
  const size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
  if (inputs_num > args.size()) {
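The detection keys off __code__.co_name: in this MindSpore version, ms_function wraps the user function in a closure named staging_specialize, which is what the C++ side reads here and in the two PrimitivePy paths below. A minimal sketch of that assumption:

    from mindspore.common.api import ms_function

    @ms_function
    def decorated(x):
        return x

    # The wrapper produced by ms_function is the closure 'staging_specialize';
    # its co_name is what DoGradForCustomBprop inspects to reject the bprop.
    print(decorated.__code__.co_name)  # expected: staging_specialize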
@@ -219,6 +219,12 @@ BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
  auto cell_id = GetValue<std::string>(this->GetAttr(kCellIDAttrName));
  auto iter = hook_grad_.find(cell_id);
  if (iter != hook_grad_.end()) {
    py::object code_obj = py::getattr(hook_, "__code__");
    py::object co_name = py::getattr(code_obj, "co_name");
    if (std::string(py::str(co_name)) == "staging_specialize") {
      MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported.";
    }
    py::tuple convert_args(input_param_nums - 1);
    py::tuple input_args(input_param_nums - 1);
    input_args[0] = iter->second;
@@ -243,6 +249,12 @@ BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const {
}

BaseRef PrimitivePy::RunVariableHookFunction(const py::tuple &py_args) const {
  py::object code_obj = py::getattr(hook_, "__code__");
  py::object co_name = py::getattr(code_obj, "co_name");
  if (std::string(py::str(co_name)) == "staging_specialize") {
    MS_LOG(EXCEPTION) << "Decorating hook function with '@ms_function' is not supported.";
  }
  constexpr size_t grad_output_index = 2;
  SyncData(py_args[grad_output_index]);
  py::object obj = hook_(py::make_tuple(py_args[grad_output_index]));
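Together, the two PrimitivePy checks turn a previously confusing failure into an explicit error the moment the hook runs. A sketch of the rejected pattern (names are illustrative; the analogous bprop case is covered by a new test below):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops.operations as P
    from mindspore.ops import composite as C
    from mindspore import context, Tensor
    from mindspore.common.api import ms_function

    @ms_function
    def bad_hook(grad_out):
        print(grad_out)

    class NetWithBadHook(nn.Cell):
        def __init__(self):
            super(NetWithBadHook, self).__init__()
            self.hook = P.HookBackward(bad_hook)

        def construct(self, x):
            return self.hook(x * x)

    context.set_context(mode=context.PYNATIVE_MODE)
    net = NetWithBadHook()
    x = Tensor(np.ones((2, 2)).astype(np.float32))
    # Expected: RuntimeError("Decorating hook function with '@ms_function' is not supported.")
    C.GradOperation(get_all=True)(net)(x)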
@@ -0,0 +1,157 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore.ops import composite as C
from mindspore import context, Tensor
from mindspore.common.api import ms_function

grad_all = C.GradOperation(get_all=True)


def var_hook_function(grad_out):
    print("grad:", grad_out)


class GraphVarHook(nn.Cell):
    def __init__(self):
        super(GraphVarHook, self).__init__()
        self.relu = nn.ReLU()
        self.hook = P.HookBackward(var_hook_function)

    def construct(self, x):
        x = x + x
        x = x * x
        x = self.hook(x)
        x = self.relu(x)
        return x


class MsFuncVarHook(nn.Cell):
    def __init__(self):
        super(MsFuncVarHook, self).__init__()
        self.relu = nn.ReLU()
        self.hook = P.HookBackward(var_hook_function)

    @ms_function
    def construct(self, x):
        x = x + x
        x = x * x
        x = self.hook(x)
        x = self.relu(x)
        return x


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_var_hook_forward():
    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
    context.set_context(mode=context.PYNATIVE_MODE)
    net1 = MsFuncVarHook()
    out1 = net1(input_x)
    context.set_context(mode=context.GRAPH_MODE)
    net2 = GraphVarHook()
    out2 = net2(input_x)
    assert np.allclose(out1.asnumpy(), out2.asnumpy(), 0.00001, 0.00001)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_var_hook_grad():
    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
    context.set_context(mode=context.PYNATIVE_MODE)
    net1 = MsFuncVarHook()
    grad_out1 = grad_all(net1)(input_x)
    context.set_context(mode=context.GRAPH_MODE)
    net2 = GraphVarHook()
    grad_out2 = grad_all(net2)(input_x)
    assert np.allclose(grad_out1[0].asnumpy(), grad_out2[0].asnumpy(), 0.00001, 0.00001)


def cell_hook_function(cell_id, grad_input, grad_output):
    print("cell id:", cell_id)
    print("grad input:", grad_input)
    print("grad output:", grad_output)


class GraphCellHook(nn.Cell):
    def __init__(self):
        super(GraphCellHook, self).__init__()
        self.relu = nn.ReLU()
        self.relu.register_backward_hook(cell_hook_function)

    def construct(self, x):
        x = x + x
        x = x * x
        x = self.relu(x)
        return x


class MsFuncCellHook(nn.Cell):
    def __init__(self):
        super(MsFuncCellHook, self).__init__()
        self.relu = nn.ReLU()
        self.relu.register_backward_hook(cell_hook_function)

    @ms_function
    def construct(self, x):
        x = x + x
        x = x * x
        x = self.relu(x)
        return x


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cell_hook_forward():
    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
    context.set_context(mode=context.PYNATIVE_MODE)
    net1 = MsFuncCellHook()
    out1 = net1(input_x)
    context.set_context(mode=context.GRAPH_MODE)
    net2 = GraphCellHook()
    out2 = net2(input_x)
    assert np.allclose(out1.asnumpy(), out2.asnumpy(), 0.00001, 0.00001)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_cell_hook_grad():
    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
    context.set_context(mode=context.PYNATIVE_MODE)
    net1 = MsFuncCellHook()
    grad_out1 = grad_all(net1)(input_x)
    context.set_context(mode=context.GRAPH_MODE)
    net2 = GraphCellHook()
    grad_out2 = grad_all(net2)(input_x)
    assert np.allclose(grad_out1[0].asnumpy(), grad_out2[0].asnumpy(), 0.00001, 0.00001)
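The four tests above pin down that eliminating the hook is value-preserving: forward outputs and gradients of the ms_function networks match the plain graph-mode networks. For contrast, in pure PyNative mode the same hooks do fire; a quick sketch reusing the classes and imports above:

    context.set_context(mode=context.PYNATIVE_MODE)
    net = GraphVarHook()  # undecorated construct, executed in PyNative
    grad_all(net)(Tensor(np.random.randn(2, 2).astype(np.float32)))
    # var_hook_function prints the incoming gradient during the backward pass.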
@@ -0,0 +1,76 @@
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore.ops import composite as C
from mindspore.nn import Momentum
from mindspore import context, Tensor
from mindspore.common.api import ms_function

grad_all = C.GradOperation(get_all=True)


class CellBprop(nn.Cell):
    def __init__(self):
        super(CellBprop, self).__init__()

    def construct(self, x, y):
        return 2 * x * x + y * y

    @ms_function
    def bprop(self, x, y, out, dout):
        return dout, 2 * y


def test_cell_bprop_grad():
    input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
    input_y = Tensor(np.random.randn(2, 2).astype(np.float32))
    context.set_context(mode=context.PYNATIVE_MODE)
    net = CellBprop()
    with pytest.raises(RuntimeError):
        grad_all(net)(input_x, input_y)


class ConvNet(nn.Cell):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")

    def construct(self, x):
        out = self.conv(x)
        return out


class MomentumWithMsFunc(nn.Cell):
    def __init__(self, net):
        super(MomentumWithMsFunc, self).__init__()
        self.net = net
        self.optimizer = Momentum(filter(lambda x: x.requires_grad, self.net.get_parameters()), 0.1, 0.9)

    @ms_function
    def construct(self, grads):
        ret = self.optimizer(grads)
        return ret


def test_ms_func_decorate_forward():
    context.set_context(mode=context.PYNATIVE_MODE)
    input_x = Tensor(np.random.randn(1, 1, 2, 2).astype(np.float32))
    net = ConvNet()
    grad_out = grad_all(net)(input_x)
    opt = MomentumWithMsFunc(net)
    opt(grad_out)
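Note the asymmetry this last test pins down: ms_function on an ordinary forward construct (here, the optimizer step) remains supported in PyNative mode; only bprop and hook functions are rejected.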