| @@ -980,7 +980,7 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh | |||
| } | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "training not parameter_tuple"; | |||
| MS_LOG(DEBUG) << "training not parameter_tuple"; | |||
| } | |||
| return w_args; | |||
| } | |||
| @@ -181,6 +181,9 @@ class Tensor(Tensor_): | |||
| def __imod__(self, other): | |||
| return self.__mod__(other) | |||
| def __pow__(self, other): | |||
| return tensor_operator_registry.get('__pow__')(self, other) | |||
| def __floordiv__(self, other): | |||
| return tensor_operator_registry.get('__floordiv__')(self, other) | |||
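A minimal usage sketch of the two new operator hooks, assuming MindSpore's Tensor and NumPy are available; the values shown are illustrative only:

    import numpy as np
    from mindspore import Tensor

    x = Tensor(np.full((2, 2), 2.0).astype(np.float32))
    p = x ** 3      # dispatches through tensor_operator_registry.get('__pow__')
    f = x // 1.5    # dispatches through tensor_operator_registry.get('__floordiv__')
    # expected: p.asnumpy()[0][0] == 8.0 and f.asnumpy()[0][0] == 1.0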
| @@ -176,7 +176,10 @@ class _Context: | |||
| self._context_switches.push(True, None) | |||
| else: | |||
| if self.enable_debug_runtime: | |||
| self.set_backend_policy("ge") | |||
| if self.device_target == "CPU": | |||
| self.set_backend_policy("vm") | |||
| else: | |||
| self.set_backend_policy("ge") | |||
| self._context_switches.push(False, None) | |||
| def set_backend_policy(self, policy): | |||
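A hedged restatement of the branch above as a standalone helper (choose_backend is a made-up name, not the _Context API), just to make the new selection rule explicit:

    def choose_backend(enable_debug_runtime, device_target):
        # Debug runtime on CPU now falls back to the "vm" backend;
        # any other debug-runtime target keeps "ge"; otherwise nothing changes.
        if enable_debug_runtime:
            return "vm" if device_target == "CPU" else "ge"
        return None

    assert choose_backend(True, "CPU") == "vm"
    assert choose_backend(True, "Ascend") == "ge"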
| @@ -16,6 +16,7 @@ | |||
| import time | |||
| import gc | |||
| from collections import OrderedDict | |||
| import numpy | |||
| from mindspore import log as logger | |||
| from .. import context | |||
| from ..common import dtype as mstype | |||
| @@ -211,6 +212,9 @@ class Cell: | |||
| if context.get_context("mode") == context.GRAPH_MODE: | |||
| out = self.compile_and_run(*inputs) | |||
| return out | |||
| for item in inputs: | |||
| if isinstance(item, numpy.ndarray): | |||
| raise TypeError("cell inputs should not be numpy array.") | |||
| self.init_parameters_data() | |||
| orign_grad = [] | |||
| if self.requires_grad is True: | |||
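A usage sketch of the new input check, assuming a trivial cell (Square is hypothetical): in pynative mode a raw numpy array must now be wrapped in a Tensor before calling the cell.

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context

    context.set_context(mode=context.PYNATIVE_MODE)

    class Square(nn.Cell):
        def construct(self, x):
            return x * x

    net = Square()
    data = np.ones((2, 2), dtype=np.float32)
    # net(data)              # would now raise TypeError: cell inputs should not be numpy array.
    out = net(Tensor(data))  # wrap the ndarray in a Tensor instead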
| @@ -17,6 +17,7 @@ | |||
| """Basic composite operations.""" | |||
| from functools import partial | |||
| from types import FunctionType | |||
| from mindspore import context | |||
| from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \ | |||
| @@ -25,6 +26,7 @@ from ...common import dtype as mstype | |||
| from ...common.api import ms_function, _pynative_exec, _wrap_func | |||
| from .. import functional as F | |||
| from ...common.parameter import Parameter | |||
| from ...common.tensor import Tensor | |||
| __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_] | |||
| @@ -114,37 +116,48 @@ class GradOperation(GradOperation_): | |||
| self.fn = None | |||
| self.need_forward = False | |||
| def _pynative_forward_run(self, args, fn): | |||
| """ Pynative forward run to build grad graph. """ | |||
| if self.sens_param: | |||
| args = args[:-1] | |||
| if isinstance(fn, FunctionType): | |||
| _pynative_exec.set_grad_flag(True) | |||
| _pynative_exec.new_graph(fn, *args) | |||
| output = fn(*args) | |||
| _pynative_exec.end_graph(fn, output, *args) | |||
| else: | |||
| if fn.is_run and not fn.requires_grad: | |||
| raise ValueError("obj must set_grad.") | |||
| if not fn.is_run: | |||
| self.need_forward = True | |||
| print("already has forward run before grad by user") | |||
| if self.need_forward: | |||
| fn.set_grad() | |||
| fn(*args) | |||
| def __call__(self, fn, weights=None): | |||
| grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param) | |||
| if self.grad_fn is None or self.fn != fn: | |||
| if self.get_by_list: | |||
| if context.get_context("mode") == context.GRAPH_MODE: | |||
| if context.get_context("mode") == context.GRAPH_MODE: | |||
| if self.get_by_list: | |||
| @ms_function(obj=fn) | |||
| def after_grad(*args): | |||
| return grad_(fn, weights)(*args) | |||
| else: | |||
| @_wrap_func | |||
| @ms_function(obj=fn) | |||
| def after_grad(*args): | |||
| if fn.is_run and not fn.requires_grad: | |||
| raise ValueError("obj must set_grad.") | |||
| if not fn.is_run: | |||
| self.need_forward = True | |||
| print("already has forward run before grad by user") | |||
| if self.need_forward: | |||
| fn.set_grad() | |||
| if self.sens_param: | |||
| f_args = args[:-1] | |||
| fn(*f_args) | |||
| else: | |||
| fn(*args) | |||
| _pynative_exec.grad(grad_, fn, weights, *args) | |||
| out = _pynative_exec(*args) | |||
| _pynative_exec.clear() | |||
| return out | |||
| return grad_(fn)(*args) | |||
| else: | |||
| @ms_function(obj=fn) | |||
| @_wrap_func | |||
| def after_grad(*args): | |||
| return grad_(fn)(*args) | |||
| for arg in args: | |||
| if not isinstance(arg, Tensor): | |||
| raise TypeError("grad inputs should be tensor in pynative mode") | |||
| self._pynative_forward_run(args, fn) | |||
| _pynative_exec.grad(grad_, fn, weights, *args) | |||
| out = _pynative_exec(*args) | |||
| _pynative_exec.clear() | |||
| return out | |||
| self.grad_fn = after_grad | |||
| self.fn = fn | |||
| return self.grad_fn | |||
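A hedged end-to-end sketch of the pynative grad path above, using the predefined C.grad_all (MulNet is an illustrative cell); note that non-Tensor positional arguments now raise TypeError:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.ops import composite as C

    context.set_context(mode=context.PYNATIVE_MODE)

    class MulNet(nn.Cell):
        def construct(self, x, y):
            return x * y

    net = MulNet()
    net.set_grad()   # optional here; a cell that already ran forward must have set_grad
    x = Tensor(np.array([2.0], np.float32))
    y = Tensor(np.array([3.0], np.float32))
    grads = C.grad_all(net)(x, y)   # forward is run internally if needed, then bprop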
| @@ -166,6 +166,7 @@ tensor_operator_registry.register('__sub__', tensor_sub) | |||
| tensor_operator_registry.register('__mul__', tensor_mul) | |||
| tensor_operator_registry.register('__truediv__', tensor_div) | |||
| tensor_operator_registry.register('__mod__', tensor_mod) | |||
| tensor_operator_registry.register('__pow__', tensor_pow) | |||
| tensor_operator_registry.register('__floordiv__', tensor_floordiv) | |||
| #ms cannot support Tensor(True) compare | |||
| tensor_operator_registry.register('__eq__', equal) | |||
| @@ -228,6 +228,7 @@ def test_biasadd_3d(): | |||
| error = np.ones(shape=[3, 4, 8]) * 1.0e-6 | |||
| context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") | |||
| net = BiasAdd() | |||
| net.set_grad() | |||
| result = net(x, b) | |||
| diff = result.asnumpy() - expect | |||
| assert np.all(diff < error) | |||
| @@ -45,6 +45,7 @@ def test_net_infer(): | |||
| def test_assign_in_while(): | |||
| context.set_context(device_target="Ascend") | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| class Net(nn.Cell): | |||
| def __init__(self, input_shape): | |||
| @@ -16,6 +16,7 @@ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore as ms | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.nn as nn | |||
| from mindspore import Parameter | |||
| @@ -24,12 +25,15 @@ from mindspore.common.initializer import initializer | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from ....mindspore_test_framework.utils.bprop_util import bprop | |||
| from .....mindspore_test_framework.utils.bprop_util import bprop | |||
| def setup_module(module): | |||
| context.set_context(mode=context.PYNATIVE_MODE) | |||
| context.set_context(device_target="CPU") | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| def teardown_module(module): | |||
| context.set_context(device_target="Ascend") | |||
| class MulAdd(nn.Cell): | |||
| def __init__(self): | |||
| @@ -45,7 +49,9 @@ class MulAdd(nn.Cell): | |||
| def test_grad_mul_add(): | |||
| mul_add = MulAdd() | |||
| assert C.grad_all(mul_add)(1, 2) == (2, 4) | |||
| x = Tensor(1, dtype=ms.int32) | |||
| y = Tensor(2, dtype=ms.int32) | |||
| assert C.grad_all(mul_add)(x, y) == (2, 4) | |||
| class InlineMulADD(nn.Cell): | |||
| @@ -60,7 +66,9 @@ class InlineMulADD(nn.Cell): | |||
| def test_grad_inline_mul_add(): | |||
| inline_mul_add = InlineMulADD() | |||
| assert C.grad_all(inline_mul_add)(1, 2) == (3, 6) | |||
| x = Tensor(1, dtype=ms.int32) | |||
| y = Tensor(2, dtype=ms.int32) | |||
| assert C.grad_all(inline_mul_add)(x, y) == (3, 6) | |||
| class WithParameter(nn.Cell): | |||
| @@ -93,7 +101,9 @@ class WithNoBprop(nn.Cell): | |||
| def test_with_no_bprop(): | |||
| with_no_bprop = WithNoBprop() | |||
| assert C.grad_all(with_no_bprop)(1, 2) == (2, 1) | |||
| x = Tensor(1, dtype=ms.int32) | |||
| y = Tensor(2, dtype=ms.int32) | |||
| assert C.grad_all(with_no_bprop)(x, y) == (2, 1) | |||
| def test_grad_in_bprop_1(): | |||
| @@ -19,21 +19,27 @@ | |||
| @Desc : | |||
| """ | |||
| import logging | |||
| import pytest | |||
| import numpy as np | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import Tensor | |||
| from mindspore import context | |||
| from mindspore.ops import composite as C | |||
| from mindspore.common.api import ms_function, _executor | |||
| from mindspore.ops._grad.grad_base import bprop_getters | |||
| from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer | |||
| from mindspore.ops.functional import tensor_add | |||
| from ...ut_filter import non_graph_engine | |||
| # pylint: disable=W0613 | |||
| # pylint: disable=W0613,W0612 | |||
| # W0613: unused-argument | |||
| log = logging.getLogger("test") | |||
| log.setLevel(level=logging.ERROR) | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| # Test case: use the parse obj interface use default parameter | |||
| @@ -135,3 +141,113 @@ def test_net_with_ndarray(): | |||
| input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32') | |||
| net(ms.Tensor(input_data)) | |||
| def test_bprop_with_wrong_output_num(): | |||
| context.set_context(check_bprop=True) | |||
| class BpropWithWrongOutputNum(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum') | |||
| def __call__(self, x, y): | |||
| return x | |||
| def infer_shape(self, x_shape, yshape): | |||
| return x_shape | |||
| def infer_dtype(self, x_type, y_type): | |||
| return x_type | |||
| @bprop_getters.register(BpropWithWrongOutputNum) | |||
| def get_bprop_with_wrong_output_num(self): | |||
| """Generate bprop for BpropWithWrongOutputNum""" | |||
| def bprop(x, y, out, dout): | |||
| return (dout,) | |||
| return bprop | |||
| class BpropWithWrongOutputNumCell(nn.Cell): | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputNumCell, self).__init__() | |||
| def construct(self, x, y): | |||
| return BpropWithWrongOutputNum()(x, y) | |||
| with pytest.raises(TypeError): | |||
| C.grad_all(BpropWithWrongOutputNumCell())(1, 2) | |||
| def test_bprop_with_wrong_output_type(): | |||
| context.set_context(check_bprop=True) | |||
| class BpropWithWrongOutputType(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType') | |||
| def __call__(self, x): | |||
| return x | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| return x_type | |||
| @bprop_getters.register(BpropWithWrongOutputType) | |||
| def get_bprop_with_wrong_output_type(self): | |||
| """Generate bprop for BpropWithWrongOutputType""" | |||
| def bprop(x, out, dout): | |||
| return (1,) | |||
| return bprop | |||
| class BpropWithWrongOutputTypeCell(nn.Cell): | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputTypeCell, self).__init__() | |||
| def construct(self, x): | |||
| return BpropWithWrongOutputType()(x) | |||
| with pytest.raises(TypeError): | |||
| C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32))) | |||
| def test_bprop_with_wrong_output_shape(): | |||
| context.set_context(check_bprop=True) | |||
| class BpropWithWrongOutputShape(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape') | |||
| def __call__(self, x): | |||
| return x | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| return x_type | |||
| @bprop_getters.register(BpropWithWrongOutputShape) | |||
| def get_bprop_with_wrong_output_shape(self): | |||
| """Generate bprop for BpropWithWrongOutputShape""" | |||
| ones = Tensor(np.ones([2,]).astype(np.int32)) | |||
| def bprop(x, out, dout): | |||
| return (ones,) | |||
| return bprop | |||
| class BpropWithWrongOutputShapeCell(nn.Cell): | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputShapeCell, self).__init__() | |||
| def construct(self, x): | |||
| return BpropWithWrongOutputShape()(x) | |||
| with pytest.raises(TypeError): | |||
| net = BpropWithWrongOutputShapeCell() | |||
| net.set_grad() | |||
| C.grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32))) | |||
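For contrast with the three failing cases above, a hedged sketch of a bprop that satisfies check_bprop: it returns one gradient per input, matching each input's shape and dtype (BpropIdentity is an illustrative name, not part of the change set):

    from mindspore.ops._grad.grad_base import bprop_getters
    from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer

    class BpropIdentity(PrimitiveWithInfer):
        @prim_attr_register
        def __init__(self):
            super(BpropIdentity, self).__init__('BpropIdentity')
        def __call__(self, x):
            return x
        def infer_shape(self, x_shape):
            return x_shape
        def infer_dtype(self, x_type):
            return x_type

    @bprop_getters.register(BpropIdentity)
    def get_bprop_identity(self):
        """Generate bprop for BpropIdentity"""
        def bprop(x, out, dout):
            return (dout,)   # same number, shape and dtype as the forward inputs
        return bprop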
| @@ -78,3 +78,9 @@ def test_tensor_imul(): | |||
| y = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32)) | |||
| x *= y | |||
| assert x.asnumpy()[0][0][0][0] == 1.0 | |||
| def test_tensor_pow(): | |||
| x = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32) * 2) | |||
| y = x ** 3 | |||
| assert y.asnumpy()[0][0][0][0] == 8.0 | |||
| @@ -89,7 +89,11 @@ def test_scalar_cast_grad(): | |||
| output = F.scalar_cast(x, input_t) | |||
| return output | |||
| gfn = C.grad(fx_cast)(input_x) | |||
| @ms_function | |||
| def grad_fx_cast(input_x): | |||
| return C.grad(fx_cast)(input_x) | |||
| gfn = grad_fx_cast(input_x) | |||
| expect_dx = 1 | |||
| assert gfn == expect_dx | |||
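The recurring change in these tests: taking the grad of a free function is now wrapped in @ms_function so it is compiled and run in graph mode instead of being evaluated eagerly. A minimal sketch of the pattern (square and grad_square are illustrative names):

    from mindspore.common.api import ms_function
    from mindspore.ops import composite as C

    def square(x):
        return x * x

    @ms_function
    def grad_square(x):
        return C.grad(square)(x)

    # grad_square(3) == 6 under graph compilation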
| @@ -133,25 +137,6 @@ def test_transpose_grad(): | |||
| assert np.all(gout[0].asnumpy() == expect) | |||
| @non_graph_engine | |||
| def test_squeeze_grad(): | |||
| """ test_squeeze_grad """ | |||
| input_tensor = Tensor(np.ones(shape=[3, 2, 1])) | |||
| squeeze = P.Squeeze(2) | |||
| def fn(x): | |||
| output = squeeze(x) | |||
| return output | |||
| out = fn(input_tensor) | |||
| gfn = grad_all_with_sens(fn) | |||
| sens = Tensor(np.ones_like(out.asnumpy())) | |||
| args = [input_tensor, sens] | |||
| gout = gfn(*args) | |||
| expect = np.ones([3, 2, 1]) | |||
| assert np.all(gout[0].asnumpy() == expect) | |||
| def test_select_grad(): | |||
| """ test_select_grad """ | |||
| select = P.Select() | |||
| @@ -176,6 +161,25 @@ def test_select_grad(): | |||
| assert np.all(gout[2].asnumpy() == expect_y) | |||
| @non_graph_engine | |||
| def test_squeeze_grad(): | |||
| """ test_squeeze_grad """ | |||
| input_tensor = Tensor(np.ones(shape=[3, 2, 1])) | |||
| squeeze = P.Squeeze(2) | |||
| def fn(x): | |||
| output = squeeze(x) | |||
| return output | |||
| out = fn(input_tensor) | |||
| gfn = grad_all_with_sens(fn) | |||
| sens = Tensor(np.ones_like(out.asnumpy())) | |||
| args = [input_tensor, sens] | |||
| gout = gfn(*args) | |||
| expect = np.ones([3, 2, 1]) | |||
| assert np.all(gout[0].asnumpy() == expect) | |||
| def test_SubGrad(): | |||
| """ test_SubGrad """ | |||
| input_x = Tensor(np.array([[2, 2]])) | |||
| @@ -16,6 +16,7 @@ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore as ms | |||
| import mindspore.nn as nn | |||
| from mindspore import context | |||
| from mindspore.common import dtype as mstype | |||
| @@ -23,8 +24,6 @@ from mindspore.common.parameter import Parameter, ParameterTuple | |||
| from mindspore.common.tensor import Tensor | |||
| from mindspore.ops import composite as C | |||
| from mindspore.ops import operations as P | |||
| from mindspore.ops._grad.grad_base import bprop_getters | |||
| from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer | |||
| from ..ut_filter import non_graph_engine | |||
| from ....mindspore_test_framework.utils.check_gradient import ( | |||
| ms_function, check_jacobian, Tensor, NNGradChecker, | |||
| @@ -156,14 +155,14 @@ def test_if_always_true(): | |||
| @non_graph_engine | |||
| def test_f(): | |||
| """ test_f """ | |||
| res = mainf(3, 2) | |||
| res = mainf(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32)) | |||
| assert res == (2, 3) | |||
| @non_graph_engine | |||
| def test_grad_add_mul(): | |||
| """ test_grad_add_mul """ | |||
| res = grad_add_mul(3, 2) | |||
| res = grad_add_mul(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32)) | |||
| assert res == (2, 7) | |||
| @@ -262,17 +261,19 @@ def test_if_tensor(): | |||
| assert res == Tensor(np.ones([1]).astype(np.int32) * 4) | |||
| @ms_function | |||
| def rec(x): | |||
| """ rec """ | |||
| if x > 0: | |||
| return rec(x - 1) | |||
| return x | |||
| @ms_function | |||
| def grad_rec(input_x): | |||
| return C.grad(rec)(input_x) | |||
| def test_grad_rec(): | |||
| """ test_grad_rec """ | |||
| res = C.grad(rec)(10) | |||
| res = grad_rec(3) | |||
| assert res == 1 | |||
| @@ -282,7 +283,6 @@ def test_me_rec(): | |||
| assert res == 0 | |||
| @ms_function | |||
| def t2_while(x, y): | |||
| out = y - x | |||
| i = 0 | |||
| @@ -298,8 +298,10 @@ def test_while2(): | |||
| def test_grad_while2(): | |||
| res = C.grad(t2_while)(2, 3) | |||
| assert res == 3 | |||
| @ms_function | |||
| def df_t2_while(input_x, input_y): | |||
| return C.grad(t2_while)(input_x, input_y) | |||
| assert df_t2_while(2, 3) == 3 | |||
| def if_test(a, b): | |||
| @@ -316,7 +318,7 @@ def grad_if(x, y): | |||
| def test_grad_if(): | |||
| """ test_grad_if """ | |||
| assert grad_if(5, 4) == (3, 0) | |||
| assert grad_if(Tensor(5, dtype=ms.int32), Tensor(4, dtype=ms.int32)) == (3, 0) | |||
| # While loop is not unrolled in forward and backward graphs. | |||
| @@ -421,7 +423,7 @@ def grad_while(x): | |||
| def test_grad_while(): | |||
| """ test_grad_while """ | |||
| assert grad_while(5) == (60,) | |||
| assert grad_while(Tensor(5, dtype=ms.int32)) == (60,) | |||
| @ms_function | |||
| @@ -438,8 +440,10 @@ def test_factorial(): | |||
| def test_grad_factorial(): | |||
| res = C.grad(factorial)(3) | |||
| assert res == 11 | |||
| @ms_function | |||
| def df_factorial(x): | |||
| return C.grad(factorial)(x) | |||
| assert df_factorial(3) == 11 | |||
| @ms_function | |||
| @@ -513,7 +517,7 @@ def _for(x): | |||
| ret = ret * i | |||
| return ret | |||
| @ms_function | |||
| def grad_for(x): | |||
| """ grad_for """ | |||
| return C.grad_all(_for)(x) | |||
| @@ -786,7 +790,10 @@ def multi_outputs(x, y): | |||
| def test_grad_multi_outputs(): | |||
| assert C.grad_all_with_sens(multi_outputs)(2, 3, (1, 1)) == (4, 4) | |||
| @ms_function | |||
| def df_multi_outputs(x, y): | |||
| return C.grad_all_with_sens(multi_outputs)(x, y, (1, 1)) | |||
| assert df_multi_outputs(2, 3) == (4, 4) | |||
| @ms_function | |||
| @@ -813,7 +820,7 @@ def grad_refactor_simple_1(x, y): | |||
| def test_grad_refactor_simple_1(): | |||
| assert C.grad_all(grad_refactor_simple_1)(2, 1) == (4, 2) | |||
| assert C.grad_all(grad_refactor_simple_1)(Tensor(2, dtype=ms.int32), Tensor(1, dtype=ms.int32)) == (4, 2) | |||
| def grad_refactor_simple_2(x, y, z): | |||
| @@ -822,7 +829,10 @@ def grad_refactor_simple_2(x, y, z): | |||
| def test_grad_refactor_simple_2(): | |||
| assert C.grad_all(grad_refactor_simple_2)(2, 3, 0) == (7, 4, 7) | |||
| x = Tensor(2, dtype=ms.int32) | |||
| y = Tensor(3, dtype=ms.int32) | |||
| z = Tensor(0, dtype=ms.int32) | |||
| assert C.grad_all(grad_refactor_simple_2)(x, y, z) == (7, 4, 7) | |||
| def grad_refactor_1(a, b): | |||
| @@ -835,7 +845,7 @@ def grad_refactor_1(a, b): | |||
| def test_grad_refactor_1(): | |||
| assert C.grad_all(grad_refactor_1)(2, 3) == (3, 2) | |||
| assert C.grad_all(grad_refactor_1)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (3, 2) | |||
| def grad_refactor_2(a, b): | |||
| @@ -848,7 +858,7 @@ def grad_refactor_2(a, b): | |||
| def test_grad_refactor_2(): | |||
| assert C.grad_all(grad_refactor_2)(2, 3) == (27, 54) | |||
| assert C.grad_all(grad_refactor_2)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (27, 54) | |||
| def grad_refactor_3(a): | |||
| @@ -859,7 +869,10 @@ def grad_refactor_3(a): | |||
| def test_grad_refactor_3(): | |||
| assert C.grad_all(grad_refactor_3)(3) == (3,) | |||
| @ms_function | |||
| def df_refactor_3(x): | |||
| return C.grad_all(grad_refactor_3)(x) | |||
| assert df_refactor_3(3) == (3,) | |||
| def grad_refactor_4(a): | |||
| @@ -870,7 +883,7 @@ def grad_refactor_4(a): | |||
| def test_grad_refactor_4(): | |||
| assert C.grad_all(grad_refactor_4)(4) == (3,) | |||
| assert C.grad_all(grad_refactor_4)(Tensor(4, dtype=ms.int32)) == (3,) | |||
| def grad_refactor_5(a): | |||
| @@ -881,7 +894,10 @@ def grad_refactor_5(a): | |||
| def test_grad_refactor_5(): | |||
| assert C.grad_all(grad_refactor_5)(1) == (1,) | |||
| @ms_function | |||
| def df_refactor_5(x): | |||
| return C.grad_all(grad_refactor_5)(x) | |||
| assert df_refactor_5(1) == (1,) | |||
| def grad_refactor_6(a, b): | |||
| @@ -892,7 +908,7 @@ def grad_refactor_6(a, b): | |||
| def test_grad_refactor_6(): | |||
| assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1) | |||
| assert C.grad_all(grad_refactor_6)(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32)) == (3, 1) | |||
| def grad_refactor_while(x): | |||
| @@ -904,7 +920,10 @@ def grad_refactor_while(x): | |||
| def test_grad_refactor_9(): | |||
| assert C.grad_all(grad_refactor_while)(3) == (6,) | |||
| @ms_function | |||
| def df_refactor_while(input_x): | |||
| return C.grad_all(grad_refactor_while)(input_x) | |||
| assert df_refactor_while(3) == (6,) | |||
| def grad_refactor__while_1(x): | |||
| @@ -919,7 +938,7 @@ def grad_refactor__while_1(x): | |||
| def test_grad_refactor_10(): | |||
| """ test_grad_while """ | |||
| assert C.grad_all(grad_refactor__while_1)(5) == (60,) | |||
| assert C.grad_all(grad_refactor__while_1)(Tensor(5, dtype=ms.int32)) == (60,) | |||
| def test_grad_refactor_11(): | |||
| @@ -985,7 +1004,10 @@ def grad_refactor_14(a, b): | |||
| def test_grad_refactor_14(): | |||
| assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9) | |||
| @ms_function | |||
| def df_refactor_14(x, y): | |||
| return C.grad_all(grad_refactor_14)(x, y) | |||
| assert df_refactor_14(2, 3) == (3, 9) | |||
| # pylint: disable=using-constant-test | |||
| @@ -1009,111 +1031,3 @@ def test_grad_if_defer_inline(): | |||
| inp = Tensor(np.ones([128, 96]).astype(np.float32)) | |||
| grads = C.grad_all(network)(inp) | |||
| assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),) | |||
| def test_bprop_with_wrong_output_num(): | |||
| context.set_context(check_bprop=True) | |||
| class BpropWithWrongOutputNum(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum') | |||
| def __call__(self, x, y): | |||
| return x | |||
| def infer_shape(self, x_shape, yshape): | |||
| return x_shape | |||
| def infer_dtype(self, x_type, y_type): | |||
| return x_type | |||
| @bprop_getters.register(BpropWithWrongOutputNum) | |||
| def get_bprop_with_wrong_output_num(self): | |||
| """Generate bprop for BpropWithWrongOutputNum""" | |||
| def bprop(x, y, out, dout): | |||
| return (dout,) | |||
| return bprop | |||
| class BpropWithWrongOutputNumCell(nn.Cell): | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputNumCell, self).__init__() | |||
| def construct(self, x, y): | |||
| return BpropWithWrongOutputNum()(x, y) | |||
| with pytest.raises(TypeError): | |||
| C.grad_all(BpropWithWrongOutputNumCell())(1, 2) | |||
| def test_bprop_with_wrong_output_type(): | |||
| context.set_context(check_bprop=True) | |||
| class BpropWithWrongOutputType(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType') | |||
| def __call__(self, x): | |||
| return x | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| return x_type | |||
| @bprop_getters.register(BpropWithWrongOutputType) | |||
| def get_bprop_with_wrong_output_type(self): | |||
| """Generate bprop for BpropWithWrongOutputType""" | |||
| def bprop(x, out, dout): | |||
| return (1,) | |||
| return bprop | |||
| class BpropWithWrongOutputTypeCell(nn.Cell): | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputTypeCell, self).__init__() | |||
| def construct(self, x): | |||
| return BpropWithWrongOutputType()(x) | |||
| with pytest.raises(TypeError): | |||
| C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32))) | |||
| def test_bprop_with_wrong_output_shape(): | |||
| context.set_context(check_bprop=True) | |||
| class BpropWithWrongOutputShape(PrimitiveWithInfer): | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape') | |||
| def __call__(self, x): | |||
| return x | |||
| def infer_shape(self, x_shape): | |||
| return x_shape | |||
| def infer_dtype(self, x_type): | |||
| return x_type | |||
| @bprop_getters.register(BpropWithWrongOutputShape) | |||
| def get_bprop_with_wrong_output_shape(self): | |||
| """Generate bprop for BpropWithWrongOutputShape""" | |||
| ones = Tensor(np.ones([2,]).astype(np.int32)) | |||
| def bprop(x, out, dout): | |||
| return (ones,) | |||
| return bprop | |||
| class BpropWithWrongOutputShapeCell(nn.Cell): | |||
| def __init__(self): | |||
| super(BpropWithWrongOutputShapeCell, self).__init__() | |||
| def construct(self, x): | |||
| return BpropWithWrongOutputShape()(x) | |||
| with pytest.raises(TypeError): | |||
| C.grad_all(BpropWithWrongOutputShapeCell())(Tensor(np.ones([64, 10]).astype(np.int32))) | |||
| @@ -13,6 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore.nn as nn | |||
| import mindspore.ops.operations as P | |||
| @@ -154,22 +155,47 @@ def test_hook(): | |||
| print(loss_output.asnumpy().shape) | |||
| bprop_debug = False | |||
| class MulAdd(nn.Cell): | |||
| def __init__(self): | |||
| super(MulAdd, self).__init__() | |||
| def construct(self, x, y): | |||
| return 2 * x + y | |||
| return 2 * x * x + y * y | |||
| def bprop(self, x, y, out, dout): | |||
| assert (x == 1) | |||
| assert (y == 2) | |||
| assert (out == 4) | |||
| assert (dout == 1) | |||
| return 3 * dout, 2 * y | |||
| global bprop_debug | |||
| bprop_debug = True | |||
| return dout, 2 * y | |||
| def test_custom_bprop(): | |||
| mul_add = MulAdd() | |||
| mul_add.bprop_debug = True | |||
| assert C.grad_all(mul_add)(1, 2) == (3, 4) | |||
| x = Tensor(np.array([1, 2, 3]).astype(np.int32)) | |||
| y = Tensor(np.array([2, 3, 4]).astype(np.int32)) | |||
| C.grad_all(mul_add)(x, y) | |||
| assert bprop_debug | |||
| class Net(nn.Cell): | |||
| def __init__(self): | |||
| super(Net, self).__init__() | |||
| def construct(self, x, y): | |||
| return 2 * x * x + y * y | |||
| def test_grad_all(): | |||
| net = Net() | |||
| x = Tensor(np.array([1, 2, 3]).astype(np.int32)) | |||
| y = Tensor(np.array([2, 3, 4]).astype(np.int32)) | |||
| res = C.grad_all(net)(x, y) | |||
| print(res) | |||
| def test_check_input(): | |||
| net = Net() | |||
| x = np.array([1, 2, 3]) | |||
| y = np.array([2, 3, 4]) | |||
| with pytest.raises(TypeError): | |||
| net(x, y) | |||
| @@ -46,6 +46,7 @@ def test_InsertGradientOf_1(): | |||
| c = x * y | |||
| return c | |||
| @ms_function | |||
| def f(x, y): | |||
| return C.grad_all(stop_test)(x, y) | |||
| @@ -80,6 +81,7 @@ def test_InsertGradientOf_2(): | |||
| def f(x, y): | |||
| return clip_test(x, y) | |||
| @ms_function | |||
| def fd(x, y): | |||
| return C.grad_all(clip_test)(x, y) | |||
| @@ -16,6 +16,7 @@ | |||
| import numpy as np | |||
| import pytest | |||
| import mindspore as ms | |||
| import mindspore.common.dtype as mstype | |||
| import mindspore.nn as nn | |||
| from mindspore import Parameter, ParameterTuple | |||
| @@ -81,16 +82,24 @@ def stop_test4(x, y): | |||
| return e | |||
| @ms_function | |||
| def grad_stop_test(x, y): | |||
| """ grad_stop_test """ | |||
| return C.grad_all(stop_test2)(x, y) | |||
| @ms_function | |||
| def grad_stop_test1(x, y): | |||
| """ grad_stop_test1 """ | |||
| return C.grad_all(stop_test3)(x, y) | |||
| @ms_function | |||
| def grad_stop_test5(x, y): | |||
| """ grad_stop_test5 """ | |||
| return C.grad_all(stop_test5)(x, y) | |||
| def test_stop(): | |||
| """ test_stop """ | |||
| print("test_stop:", grad_stop_test(1, 1)) | |||
| @@ -103,7 +112,7 @@ def test_stop1(): | |||
| def test_stop5(): | |||
| """ test_stop1 """ | |||
| print("test_stop5:", C.grad_all(stop_test5)(2, 3)) | |||
| print("test_stop5:", grad_stop_test5(2, 3)) | |||
| class GradWrap(nn.Cell): | |||
| @@ -247,7 +256,7 @@ def test_stop_gradient_4(): | |||
| def stop_test(x): | |||
| return stop_gradient(x) | |||
| assert C.grad_all(stop_test)(1) == (0,) | |||
| assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (0,) | |||
| def test_stop_gradient_5(): | |||
| @@ -257,7 +266,7 @@ def test_stop_gradient_5(): | |||
| ret = x + y | |||
| return ret | |||
| assert C.grad_all(stop_test)(1) == (1,) | |||
| assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,) | |||
| def test_stop_gradient_6(): | |||
| @@ -266,7 +275,7 @@ def test_stop_gradient_6(): | |||
| ret = stop_gradient(ret) | |||
| return ret | |||
| assert C.grad_all(stop_test)(1, 3) == (0, 0) | |||
| assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (0, 0) | |||
| class PrimWithMultiOutputs(PrimitiveWithInfer): | |||