@@ -980,7 +980,7 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
       }
     }
   } else {
-    MS_LOG(EXCEPTION) << "training not paramter_tuple";
+    MS_LOG(DEBUG) << "training not paramter_tuple";
   }
   return w_args;
 }
@@ -181,6 +181,9 @@ class Tensor(Tensor_):
     def __imod__(self, other):
         return self.__mod__(other)
 
+    def __pow__(self, other):
+        return tensor_operator_registry.get('__pow__')(self, other)
+
     def __floordiv__(self, other):
         return tensor_operator_registry.get('__floordiv__')(self, other)
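
Note: a minimal usage sketch of the new Tensor power operator (values follow the test_tensor_pow case added later in this diff; the shape here is illustrative):

    import numpy as np
    from mindspore import Tensor

    x = Tensor(np.ones([2, 2]).astype(np.float32) * 2)
    y = x ** 3          # dispatches through tensor_operator_registry.get('__pow__')
    print(y.asnumpy())  # every element is expected to be 8.0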
@@ -176,7 +176,10 @@ class _Context:
             self._context_switches.push(True, None)
         else:
             if self.enable_debug_runtime:
-                self.set_backend_policy("ge")
+                if self.device_target == "CPU":
+                    self.set_backend_policy("vm")
+                else:
+                    self.set_backend_policy("ge")
             self._context_switches.push(False, None)
 
     def set_backend_policy(self, policy):
@@ -16,6 +16,7 @@
 import time
 import gc
 from collections import OrderedDict
+import numpy
 from mindspore import log as logger
 from .. import context
 from ..common import dtype as mstype
@@ -211,6 +212,9 @@ class Cell:
         if context.get_context("mode") == context.GRAPH_MODE:
             out = self.compile_and_run(*inputs)
             return out
+        for item in inputs:
+            if isinstance(item, numpy.ndarray):
+                raise TypeError("cell inputs should not be numpy array.")
         self.init_parameters_data()
         orign_grad = []
         if self.requires_grad is True:
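
Note: a short sketch of what the new PyNative input check means for callers (the Net class is hypothetical but mirrors the test_check_input / test_grad_all cases added later in this diff); numpy inputs must be wrapped in Tensor before calling a Cell:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context

    context.set_context(mode=context.PYNATIVE_MODE)

    class Net(nn.Cell):
        def construct(self, x, y):
            return 2 * x * x + y * y

    net = Net()
    data = np.array([1, 2, 3]).astype(np.float32)
    # net(data, data) would now raise TypeError("cell inputs should not be numpy array.")
    out = net(Tensor(data), Tensor(data))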
@@ -17,6 +17,7 @@
 """Basic composite operations."""
 from functools import partial
+from types import FunctionType
 
 from mindspore import context
 from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
@@ -25,6 +26,7 @@ from ...common import dtype as mstype
 from ...common.api import ms_function, _pynative_exec, _wrap_func
 from .. import functional as F
 from ...common.parameter import Parameter
+from ...common.tensor import Tensor
 
 __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
@@ -114,37 +116,48 @@ class GradOperation(GradOperation_):
         self.fn = None
         self.need_forward = False
 
+    def _pynative_forward_run(self, args, fn):
+        """ Pynative forward run to build grad graph. """
+        if self.sens_param:
+            args = args[:-1]
+        if isinstance(fn, FunctionType):
+            _pynative_exec.set_grad_flag(True)
+            _pynative_exec.new_graph(fn, *args)
+            output = fn(*args)
+            _pynative_exec.end_graph(fn, output, *args)
+        else:
+            if fn.is_run and not fn.requires_grad:
+                raise ValueError("obj must set_grad.")
+            if not fn.is_run:
+                self.need_forward = True
+                print("already has forward run before grad by user")
+            if self.need_forward:
+                fn.set_grad()
+                fn(*args)
+
     def __call__(self, fn, weights=None):
         grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)
         if self.grad_fn is None or self.fn != fn:
-            if self.get_by_list:
-                if context.get_context("mode") == context.GRAPH_MODE:
+            if context.get_context("mode") == context.GRAPH_MODE:
+                if self.get_by_list:
                     @ms_function(obj=fn)
                     def after_grad(*args):
                         return grad_(fn, weights)(*args)
                 else:
-                    @_wrap_func
+                    @ms_function(obj=fn)
                     def after_grad(*args):
-                        if fn.is_run and not fn.requires_grad:
-                            raise ValueError("obj must set_grad.")
-                        if not fn.is_run:
-                            self.need_forward = True
-                            print("already has forward run before grad by user")
-                        if self.need_forward:
-                            fn.set_grad()
-                            if self.sens_param:
-                                f_args = args[:-1]
-                                fn(*f_args)
-                            else:
-                                fn(*args)
-                        _pynative_exec.grad(grad_, fn, weights, *args)
-                        out = _pynative_exec(*args)
-                        _pynative_exec.clear()
-                        return out
+                        return grad_(fn)(*args)
             else:
-                @ms_function(obj=fn)
+                @_wrap_func
                 def after_grad(*args):
-                    return grad_(fn)(*args)
+                    for arg in args:
+                        if not isinstance(arg, Tensor):
+                            raise TypeError("grad inputs should be tensor in pynative mode")
+                    self._pynative_forward_run(args, fn)
+                    _pynative_exec.grad(grad_, fn, weights, *args)
+                    out = _pynative_exec(*args)
+                    _pynative_exec.clear()
+                    return out
             self.grad_fn = after_grad
             self.fn = fn
         return self.grad_fn
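
Note: a sketch of how the reworked GradOperation is exercised in PyNative mode, assuming behavior matching the tests added later in this diff (the Net class is illustrative):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.ops import composite as C

    context.set_context(mode=context.PYNATIVE_MODE)

    class Net(nn.Cell):
        def construct(self, x, y):
            return 2 * x * x + y * y

    net = Net()
    x = Tensor(np.array([1, 2, 3]).astype(np.float32))
    y = Tensor(np.array([2, 3, 4]).astype(np.float32))
    grads = C.grad_all(net)(x, y)  # passing Python scalars here now raises TypeError
    # If the Cell was already run forward without gradients enabled, net.set_grad()
    # must be called first, otherwise GradOperation raises ValueError("obj must set_grad.").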
@@ -166,6 +166,7 @@ tensor_operator_registry.register('__sub__', tensor_sub)
 tensor_operator_registry.register('__mul__', tensor_mul)
 tensor_operator_registry.register('__truediv__', tensor_div)
 tensor_operator_registry.register('__mod__', tensor_mod)
+tensor_operator_registry.register('__pow__', tensor_pow)
 tensor_operator_registry.register('__floordiv__', tensor_floordiv)
 #ms cannot support Tensor(True) compare
 tensor_operator_registry.register('__eq__', equal)
@@ -228,6 +228,7 @@ def test_biasadd_3d():
     error = np.ones(shape=[3, 4, 8]) * 1.0e-6
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
     net = BiasAdd()
+    net.set_grad()
     result = net(x, b)
     diff = result.asnumpy() - expect
     assert np.all(diff < error)
@@ -45,6 +45,7 @@ def test_net_infer():
 def test_assign_in_while():
+    context.set_context(device_target="Ascend")
     context.set_context(mode=context.GRAPH_MODE)
     class Net(nn.Cell):
         def __init__(self, input_shape):
@@ -16,6 +16,7 @@
 import numpy as np
 import pytest
 
+import mindspore as ms
 import mindspore.common.dtype as mstype
 import mindspore.nn as nn
 from mindspore import Parameter
@@ -24,12 +25,15 @@ from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-from ....mindspore_test_framework.utils.bprop_util import bprop
+from .....mindspore_test_framework.utils.bprop_util import bprop
 
 
 def setup_module(module):
-    context.set_context(mode=context.PYNATIVE_MODE)
+    context.set_context(device_target="CPU")
+    context.set_context(mode=context.GRAPH_MODE)
+
+
+def teardown_module(module):
+    context.set_context(device_target="Ascend")
 
 
 class MulAdd(nn.Cell):
     def __init__(self):
@@ -45,7 +49,9 @@ class MulAdd(nn.Cell):
 def test_grad_mul_add():
     mul_add = MulAdd()
-    assert C.grad_all(mul_add)(1, 2) == (2, 4)
+    x = Tensor(1, dtype=ms.int32)
+    y = Tensor(2, dtype=ms.int32)
+    assert C.grad_all(mul_add)(x, y) == (2, 4)
 
 
 class InlineMulADD(nn.Cell):
@@ -60,7 +66,9 @@ class InlineMulADD(nn.Cell):
 def test_grad_inline_mul_add():
     inline_mul_add = InlineMulADD()
-    assert C.grad_all(inline_mul_add)(1, 2) == (3, 6)
+    x = Tensor(1, dtype=ms.int32)
+    y = Tensor(2, dtype=ms.int32)
+    assert C.grad_all(inline_mul_add)(x, y) == (3, 6)
 
 
 class WithParameter(nn.Cell):
@@ -93,7 +101,9 @@ class WithNoBprop(nn.Cell):
 def test_with_no_bprop():
     with_no_bprop = WithNoBprop()
-    assert C.grad_all(with_no_bprop)(1, 2) == (2, 1)
+    x = Tensor(1, dtype=ms.int32)
+    y = Tensor(2, dtype=ms.int32)
+    assert C.grad_all(with_no_bprop)(x, y) == (2, 1)
 
 
 def test_grad_in_bprop_1():
@@ -19,21 +19,27 @@
 @Desc :
 """
 import logging
+import pytest
 import numpy as np
 
 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import composite as C
 from mindspore.common.api import ms_function, _executor
+from mindspore.ops._grad.grad_base import bprop_getters
+from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
 from mindspore.ops.functional import tensor_add
 from ...ut_filter import non_graph_engine
 
-# pylint: disable=W0613
+# pylint: disable=W0613,W0612
 # W0613: unused-argument
 
 log = logging.getLogger("test")
 log.setLevel(level=logging.ERROR)
+context.set_context(mode=context.GRAPH_MODE)
 
 
 # Test case: use the parse obj interface use default parameter
@@ -135,3 +141,113 @@ def test_net_with_ndarray():
     input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
     net(ms.Tensor(input_data))
+
+
+def test_bprop_with_wrong_output_num():
+    context.set_context(check_bprop=True)
+
+    class BpropWithWrongOutputNum(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
+
+        def __call__(self, x, y):
+            return x
+
+        def infer_shape(self, x_shape, yshape):
+            return x_shape
+
+        def infer_dtype(self, x_type, y_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputNum)
+    def get_bprop_with_wrong_output_num(self):
+        """Generate bprop for BpropWithWrongOutputNum"""
+        def bprop(x, y, out, dout):
+            return (dout,)
+        return bprop
+
+    class BpropWithWrongOutputNumCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputNumCell, self).__init__()
+
+        def construct(self, x, y):
+            return BpropWithWrongOutputNum()(x, y)
+
+    with pytest.raises(TypeError):
+        C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
+
+
+def test_bprop_with_wrong_output_type():
+    context.set_context(check_bprop=True)
+
+    class BpropWithWrongOutputType(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
+
+        def __call__(self, x):
+            return x
+
+        def infer_shape(self, x_shape):
+            return x_shape
+
+        def infer_dtype(self, x_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputType)
+    def get_bprop_with_wrong_output_type(self):
+        """Generate bprop for BpropWithWrongOutputType"""
+        def bprop(x, out, dout):
+            return (1,)
+        return bprop
+
+    class BpropWithWrongOutputTypeCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputTypeCell, self).__init__()
+
+        def construct(self, x):
+            return BpropWithWrongOutputType()(x)
+
+    with pytest.raises(TypeError):
+        C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
+
+
+def test_bprop_with_wrong_output_shape():
+    context.set_context(check_bprop=True)
+
+    class BpropWithWrongOutputShape(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
+
+        def __call__(self, x):
+            return x
+
+        def infer_shape(self, x_shape):
+            return x_shape
+
+        def infer_dtype(self, x_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputShape)
+    def get_bprop_with_wrong_output_shape(self):
+        """Generate bprop for BpropWithWrongOutputShape"""
+        ones = Tensor(np.ones([2,]).astype(np.int32))
+        def bprop(x, out, dout):
+            return (ones,)
+        return bprop
+
+    class BpropWithWrongOutputShapeCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputShapeCell, self).__init__()
+
+        def construct(self, x):
+            return BpropWithWrongOutputShape()(x)
+
+    with pytest.raises(TypeError):
+        net = BpropWithWrongOutputShapeCell()
+        net.set_grad()
+        C.grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))
@@ -78,3 +78,9 @@ def test_tensor_imul():
     y = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32))
     x *= y
     assert x.asnumpy()[0][0][0][0] == 1.0
+
+
+def test_tensor_pow():
+    x = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32) * 2)
+    y = x ** 3
+    assert y.asnumpy()[0][0][0][0] == 8.0
@@ -89,7 +89,11 @@ def test_scalar_cast_grad():
         output = F.scalar_cast(x, input_t)
         return output
 
-    gfn = C.grad(fx_cast)(input_x)
+    @ms_function
+    def grad_fx_cast(input_x):
+        return C.grad(fx_cast)(input_x)
+
+    gfn = grad_fx_cast(input_x)
     expect_dx = 1
     assert gfn == expect_dx
@@ -133,25 +137,6 @@ def test_transpose_grad():
     assert np.all(gout[0].asnumpy() == expect)
 
 
-@non_graph_engine
-def test_squeeze_grad():
-    """ test_squeeze_grad """
-    input_tensor = Tensor(np.ones(shape=[3, 2, 1]))
-    squeeze = P.Squeeze(2)
-
-    def fn(x):
-        output = squeeze(x)
-        return output
-
-    out = fn(input_tensor)
-    gfn = grad_all_with_sens(fn)
-    sens = Tensor(np.ones_like(out.asnumpy()))
-    args = [input_tensor, sens]
-    gout = gfn(*args)
-    expect = np.ones([3, 2, 1])
-    assert np.all(gout[0].asnumpy() == expect)
-
-
 def test_select_grad():
     """ test_select_grad """
     select = P.Select()
@@ -176,6 +161,25 @@ def test_select_grad():
     assert np.all(gout[2].asnumpy() == expect_y)
 
 
+@non_graph_engine
+def test_squeeze_grad():
+    """ test_squeeze_grad """
+    input_tensor = Tensor(np.ones(shape=[3, 2, 1]))
+    squeeze = P.Squeeze(2)
+
+    def fn(x):
+        output = squeeze(x)
+        return output
+
+    out = fn(input_tensor)
+    gfn = grad_all_with_sens(fn)
+    sens = Tensor(np.ones_like(out.asnumpy()))
+    args = [input_tensor, sens]
+    gout = gfn(*args)
+    expect = np.ones([3, 2, 1])
+    assert np.all(gout[0].asnumpy() == expect)
+
+
 def test_SubGrad():
     """ test_SubGrad """
     input_x = Tensor(np.array([[2, 2]]))
@@ -16,6 +16,7 @@
 import numpy as np
 import pytest
 
+import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.common import dtype as mstype
@@ -23,8 +24,6 @@ from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.tensor import Tensor
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-from mindspore.ops._grad.grad_base import bprop_getters
-from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
 from ..ut_filter import non_graph_engine
 from ....mindspore_test_framework.utils.check_gradient import (
     ms_function, check_jacobian, Tensor, NNGradChecker,
@@ -156,14 +155,14 @@ def test_if_always_true():
 @non_graph_engine
 def test_f():
     """ test_f """
-    res = mainf(3, 2)
+    res = mainf(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
     assert res == (2, 3)
 
 
 @non_graph_engine
 def test_grad_add_mul():
     """ test_grad_add_mul """
-    res = grad_add_mul(3, 2)
+    res = grad_add_mul(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
     assert res == (2, 7)
@@ -262,17 +261,19 @@ def test_if_tensor():
     assert res == Tensor(np.ones([1]).astype(np.int32) * 4)
 
 
-@ms_function
 def rec(x):
     """ rec """
     if x > 0:
         return rec(x - 1)
     return x
 
 
+@ms_function
+def grad_rec(input_x):
+    return C.grad(rec)(input_x)
+
+
 def test_grad_rec():
     """ test_grad_rec """
-    res = C.grad(rec)(10)
+    res = grad_rec(3)
     assert res == 1
@@ -282,7 +283,6 @@ def test_me_rec():
     assert res == 0
 
 
-@ms_function
 def t2_while(x, y):
     out = y - x
     i = 0
@@ -298,8 +298,10 @@ def test_while2():
 def test_grad_while2():
-    res = C.grad(t2_while)(2, 3)
-    assert res == 3
+    @ms_function
+    def df_t2_while(input_x, input_y):
+        return C.grad(t2_while)(input_x, input_y)
+
+    assert df_t2_while(2, 3) == 3
 
 
 def if_test(a, b):
@@ -316,7 +318,7 @@ def grad_if(x, y):
 def test_grad_if():
     """ test_grad_if """
-    assert grad_if(5, 4) == (3, 0)
+    assert grad_if(Tensor(5, dtype=ms.int32), Tensor(4, dtype=ms.int32)) == (3, 0)
 
 
 # While loop is not unrolled in forward and backward graphs.
@@ -421,7 +423,7 @@ def grad_while(x):
 def test_grad_while():
     """ test_grad_while """
-    assert grad_while(5) == (60,)
+    assert grad_while(Tensor(5, dtype=ms.int32)) == (60,)
 
 
 @ms_function
@@ -438,8 +440,10 @@ def test_factorial():
 def test_grad_factorial():
-    res = C.grad(factorial)(3)
-    assert res == 11
+    @ms_function
+    def df_factorial(x):
+        return C.grad(factorial)(x)
+
+    assert df_factorial(3) == 11
 
 
 @ms_function
@@ -513,7 +517,7 @@ def _for(x):
         ret = ret * i
     return ret
 
 
+@ms_function
 def grad_for(x):
     """ grad_for """
     return C.grad_all(_for)(x)
@@ -786,7 +790,10 @@ def multi_outputs(x, y):
 def test_grad_multi_outputs():
-    assert C.grad_all_with_sens(multi_outputs)(2, 3, (1, 1)) == (4, 4)
+    @ms_function
+    def df_multi_outputs(x, y):
+        return C.grad_all_with_sens(multi_outputs)(x, y, (1, 1))
+
+    assert df_multi_outputs(2, 3) == (4, 4)
 
 
 @ms_function
@@ -813,7 +820,7 @@ def grad_refactor_simple_1(x, y):
 def test_grad_refactor_simple_1():
-    assert C.grad_all(grad_refactor_simple_1)(2, 1) == (4, 2)
+    assert C.grad_all(grad_refactor_simple_1)(Tensor(2, dtype=ms.int32), Tensor(1, dtype=ms.int32)) == (4, 2)
 
 
 def grad_refactor_simple_2(x, y, z):
@@ -822,7 +829,10 @@ def grad_refactor_simple_2(x, y, z):
 def test_grad_refactor_simple_2():
-    assert C.grad_all(grad_refactor_simple_2)(2, 3, 0) == (7, 4, 7)
+    x = Tensor(2, dtype=ms.int32)
+    y = Tensor(3, dtype=ms.int32)
+    z = Tensor(0, dtype=ms.int32)
+    assert C.grad_all(grad_refactor_simple_2)(x, y, z) == (7, 4, 7)
 
 
 def grad_refactor_1(a, b):
@@ -835,7 +845,7 @@ def grad_refactor_1(a, b):
 def test_grad_refactor_1():
-    assert C.grad_all(grad_refactor_1)(2, 3) == (3, 2)
+    assert C.grad_all(grad_refactor_1)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (3, 2)
 
 
 def grad_refactor_2(a, b):
@@ -848,7 +858,7 @@ def grad_refactor_2(a, b):
 def test_grad_refactor_2():
-    assert C.grad_all(grad_refactor_2)(2, 3) == (27, 54)
+    assert C.grad_all(grad_refactor_2)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (27, 54)
 
 
 def grad_refactor_3(a):
@@ -859,7 +869,10 @@ def grad_refactor_3(a):
 def test_grad_refactor_3():
-    assert C.grad_all(grad_refactor_3)(3) == (3,)
+    @ms_function
+    def df_refactor_3(x):
+        return C.grad_all(grad_refactor_3)(x)
+
+    assert df_refactor_3(3) == (3,)
 
 
 def grad_refactor_4(a):
@@ -870,7 +883,7 @@ def grad_refactor_4(a):
 def test_grad_refactor_4():
-    assert C.grad_all(grad_refactor_4)(4) == (3,)
+    assert C.grad_all(grad_refactor_4)(Tensor(4, dtype=ms.int32)) == (3,)
 
 
 def grad_refactor_5(a):
@@ -881,7 +894,10 @@ def grad_refactor_5(a):
 def test_grad_refactor_5():
-    assert C.grad_all(grad_refactor_5)(1) == (1,)
+    @ms_function
+    def df_refactor_5(x):
+        return C.grad_all(grad_refactor_5)(x)
+
+    assert df_refactor_5(1) == (1,)
 
 
 def grad_refactor_6(a, b):
@@ -892,7 +908,7 @@ def grad_refactor_6(a, b):
 def test_grad_refactor_6():
-    assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1)
+    assert C.grad_all(grad_refactor_6)(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32)) == (3, 1)
 
 
 def grad_refactor_while(x):
@@ -904,7 +920,10 @@ def grad_refactor_while(x):
 def test_grad_refactor_9():
-    assert C.grad_all(grad_refactor_while)(3) == (6,)
+    @ms_function
+    def df_refactor_while(input_x):
+        return C.grad_all(grad_refactor_while)(input_x)
+
+    assert df_refactor_while(3) == (6,)
 
 
 def grad_refactor__while_1(x):
@@ -919,7 +938,7 @@ def grad_refactor__while_1(x):
 def test_grad_refactor_10():
     """ test_grad_while """
-    assert C.grad_all(grad_refactor__while_1)(5) == (60,)
+    assert C.grad_all(grad_refactor__while_1)(Tensor(5, dtype=ms.int32)) == (60,)
 
 
 def test_grad_refactor_11():
@@ -985,7 +1004,10 @@ def grad_refactor_14(a, b):
 def test_grad_refactor_14():
-    assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)
+    @ms_function
+    def df_refactor_14(x, y):
+        return C.grad_all(grad_refactor_14)(x, y)
+
+    assert df_refactor_14(2, 3) == (3, 9)
 
 
 # pylint: disable=using-constant-test
@@ -1009,111 +1031,3 @@ def test_grad_if_defer_inline():
     inp = Tensor(np.ones([128, 96]).astype(np.float32))
     grads = C.grad_all(network)(inp)
     assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
-
-
-def test_bprop_with_wrong_output_num():
-    context.set_context(check_bprop=True)
-
-    class BpropWithWrongOutputNum(PrimitiveWithInfer):
-        @prim_attr_register
-        def __init__(self):
-            super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
-
-        def __call__(self, x, y):
-            return x
-
-        def infer_shape(self, x_shape, yshape):
-            return x_shape
-
-        def infer_dtype(self, x_type, y_type):
-            return x_type
-
-    @bprop_getters.register(BpropWithWrongOutputNum)
-    def get_bprop_with_wrong_output_num(self):
-        """Generate bprop for BpropWithWrongOutputNum"""
-        def bprop(x, y, out, dout):
-            return (dout,)
-        return bprop
-
-    class BpropWithWrongOutputNumCell(nn.Cell):
-        def __init__(self):
-            super(BpropWithWrongOutputNumCell, self).__init__()
-
-        def construct(self, x, y):
-            return BpropWithWrongOutputNum()(x, y)
-
-    with pytest.raises(TypeError):
-        C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
-
-
-def test_bprop_with_wrong_output_type():
-    context.set_context(check_bprop=True)
-
-    class BpropWithWrongOutputType(PrimitiveWithInfer):
-        @prim_attr_register
-        def __init__(self):
-            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
-
-        def __call__(self, x):
-            return x
-
-        def infer_shape(self, x_shape):
-            return x_shape
-
-        def infer_dtype(self, x_type):
-            return x_type
-
-    @bprop_getters.register(BpropWithWrongOutputType)
-    def get_bprop_with_wrong_output_type(self):
-        """Generate bprop for BpropWithWrongOutputType"""
-        def bprop(x, out, dout):
-            return (1,)
-        return bprop
-
-    class BpropWithWrongOutputTypeCell(nn.Cell):
-        def __init__(self):
-            super(BpropWithWrongOutputTypeCell, self).__init__()
-
-        def construct(self, x):
-            return BpropWithWrongOutputType()(x)
-
-    with pytest.raises(TypeError):
-        C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
-
-
-def test_bprop_with_wrong_output_shape():
-    context.set_context(check_bprop=True)
-
-    class BpropWithWrongOutputShape(PrimitiveWithInfer):
-        @prim_attr_register
-        def __init__(self):
-            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
-
-        def __call__(self, x):
-            return x
-
-        def infer_shape(self, x_shape):
-            return x_shape
-
-        def infer_dtype(self, x_type):
-            return x_type
-
-    @bprop_getters.register(BpropWithWrongOutputShape)
-    def get_bprop_with_wrong_output_shape(self):
-        """Generate bprop for BpropWithWrongOutputShape"""
-        ones = Tensor(np.ones([2,]).astype(np.int32))
-        def bprop(x, out, dout):
-            return (ones,)
-        return bprop
-
-    class BpropWithWrongOutputShapeCell(nn.Cell):
-        def __init__(self):
-            super(BpropWithWrongOutputShapeCell, self).__init__()
-
-        def construct(self, x):
-            return BpropWithWrongOutputShape()(x)
-
-    with pytest.raises(TypeError):
-        C.grad_all(BpropWithWrongOutputShapeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
+import pytest
 
 import mindspore.nn as nn
 import mindspore.ops.operations as P
@@ -154,22 +155,47 @@ def test_hook():
     print(loss_output.asnumpy().shape)
 
 
+bprop_debug = False
+
 class MulAdd(nn.Cell):
     def __init__(self):
         super(MulAdd, self).__init__()
 
     def construct(self, x, y):
-        return 2 * x + y
+        return 2 * x * x + y * y
 
     def bprop(self, x, y, out, dout):
-        assert (x == 1)
-        assert (y == 2)
-        assert (out == 4)
-        assert (dout == 1)
-        return 3 * dout, 2 * y
+        global bprop_debug
+        bprop_debug = True
+        return dout, 2 * y
 
 
 def test_custom_bprop():
     mul_add = MulAdd()
     mul_add.bprop_debug = True
-    assert C.grad_all(mul_add)(1, 2) == (3, 4)
+    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
+    y = Tensor(np.array([2, 3, 4]).astype(np.int32))
+    C.grad_all(mul_add)(x, y)
+    assert bprop_debug
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+
+    def construct(self, x, y):
+        return 2 * x * x + y * y
+
+
+def test_grad_all():
+    net = Net()
+    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
+    y = Tensor(np.array([2, 3, 4]).astype(np.int32))
+    res = C.grad_all(net)(x, y)
+    print(res)
+
+
+def test_check_input():
+    net = Net()
+    x = np.array([1, 2, 3])
+    y = np.array([2, 3, 4])
+    with pytest.raises(TypeError):
+        net(x, y)
@@ -46,6 +46,7 @@ def test_InsertGradientOf_1():
         c = x * y
         return c
 
+    @ms_function
    def f(x, y):
         return C.grad_all(stop_test)(x, y)
@@ -80,6 +81,7 @@ def test_InsertGradientOf_2():
     def f(x, y):
         return clip_test(x, y)
 
+    @ms_function
     def fd(x, y):
         return C.grad_all(clip_test)(x, y)
@@ -16,6 +16,7 @@
 import numpy as np
 import pytest
 
+import mindspore as ms
 import mindspore.common.dtype as mstype
 import mindspore.nn as nn
 from mindspore import Parameter, ParameterTuple
@@ -81,16 +82,24 @@ def stop_test4(x, y):
     return e
 
 
+@ms_function
 def grad_stop_test(x, y):
     """ grad_stop_test """
     return C.grad_all(stop_test2)(x, y)
 
 
+@ms_function
 def grad_stop_test1(x, y):
     """ grad_stop_test1 """
     return C.grad_all(stop_test3)(x, y)
 
 
+@ms_function
+def grad_stop_test5(x, y):
+    """ grad_stop_test5 """
+    return C.grad_all(stop_test5)(x, y)
+
+
 def test_stop():
     """ test_stop """
     print("test_stop:", grad_stop_test(1, 1))
@@ -103,7 +112,7 @@ def test_stop1():
 def test_stop5():
     """ test_stop1 """
-    print("test_stop5:", C.grad_all(stop_test5)(2, 3))
+    print("test_stop5:", grad_stop_test5(2, 3))
 
 
 class GradWrap(nn.Cell):
@@ -247,7 +256,7 @@ def test_stop_gradient_4():
     def stop_test(x):
         return stop_gradient(x)
 
-    assert C.grad_all(stop_test)(1) == (0,)
+    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (0,)
 
 
 def test_stop_gradient_5():
@@ -257,7 +266,7 @@ def test_stop_gradient_5():
     ret = x + y
         return ret
 
-    assert C.grad_all(stop_test)(1) == (1,)
+    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,)
 
 
 def test_stop_gradient_6():
@@ -266,7 +275,7 @@ def test_stop_gradient_6():
         ret = stop_gradient(ret)
         return ret
 
-    assert C.grad_all(stop_test)(1, 3) == (0, 0)
+    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (0, 0)
 
 
 class PrimWithMultiOutputs(PrimitiveWithInfer):