# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_stop_gradient """
import numpy as np
import pytest

import mindspore as ms
import mindspore.common.dtype as mstype
import mindspore.nn as nn
from mindspore import Parameter, ParameterTuple
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import ms_function
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.ops.functional import stop_gradient
from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.utils.bprop_util import bprop


def setup_module(module):
    context.set_context(mode=context.PYNATIVE_MODE)


def stop_func(x, y):
    """ stop_func """
    c = x * y
    c_s = x + y
    return c_s, c


def stop_test1(x, y):
    """ stop_test1 """
    c = x * y
    c_s = stop_gradient(c)
    return c_s


def stop_test2(x, y):
    """ stop_test2 """
    c = x * y
    c_s = stop_gradient(c)
    d = c_s + x * y
    return d * y


def stop_test3(x, y):
    """ stop_test3 """
    x = x * y
    z = stop_test1(x, y)
    k = z * y
    return k

def stop_test5(x, y):
    """ stop_test5 """
    x = x + y
    o1, o2 = stop_func(x, y)
    c = stop_gradient(o1)
    c = o2 + c
    return c


def stop_test4(x, y):
    """ stop_test4 """
    c = x + y
    c_s = stop_gradient(c)
    e = c + c_s
    return e


@ms_function
def grad_stop_test(x, y):
    """ grad_stop_test """
    return C.grad_all(stop_test2)(x, y)


@ms_function
def grad_stop_test1(x, y):
    """ grad_stop_test1 """
    return C.grad_all(stop_test3)(x, y)


@ms_function
def grad_stop_test5(x, y):
    """ grad_stop_test5 """
    return C.grad_all(stop_test5)(x, y)


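# stop_test2 returns (c_s + x * y) * y with c_s = stop_gradient(x * y) held
# constant, so d/dx = y**2 and d/dy = c_s + 2*x*y; at (1, 1) this prints (1, 3).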
def test_stop():
    """ test_stop """
    print("test_stop:", grad_stop_test(1, 1))


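# stop_test3 multiplies the fully stopped value z = stop_gradient((x * y) * y)
# by y, so only the last multiplication contributes: at (2, 3), z = 18 and the
# gradients are (0, 18).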
def test_stop1():
    """ test_stop1 """
    print("test_stop1:", grad_stop_test1(2, 3))


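# stop_test5 stops only o1 = (x + y) + y; the surviving term is (x + y) * y,
# so the gradients at (2, 3) are (y, x + 2*y) = (3, 8).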
def test_stop5():
    """ test_stop5 """
    print("test_stop5:", grad_stop_test5(2, 3))


class GradWrap(nn.Cell):
    """ GradWrap definition """

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network
        self.weights = ParameterTuple(network.get_parameters())

    @ms_function
    def construct(self, x, label):
        weights = self.weights
        return C.grad_by_list(self.network, weights)(x, label)


@non_graph_engine
def test_softmaxloss_grad():
    """ test_softmaxloss_grad """

    class NetWithLossClass(nn.Cell):
        """ NetWithLossClass definition """

        def __init__(self, network):
            super(NetWithLossClass, self).__init__()
            self.loss = nn.SoftmaxCrossEntropyWithLogits()
            self.network = network

        @ms_function
        def construct(self, x, label):
            predict = self.network(x)
            return self.loss(predict, label)

    class Net(nn.Cell):
        """ Net definition """

        def __init__(self):
            super(Net, self).__init__()
            self.weight = Parameter(Tensor(np.ones([64, 10])), name="weight")
            self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="bias")
            self.fc = P.MatMul()
            self.fc2 = nn.Dense(10, 10)
            self.bias_add = P.BiasAdd()
            self.relu = nn.ReLU()
            self.cast = P.Cast()

        @ms_function
        def construct(self, x):
            x = self.fc(x, self.weight)
            x = self.cast(x, mstype.float32)
            x = self.relu(self.fc2(x))
            x = self.fc2(x)
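            # Everything computed above is detached here, so only the bias add
            # below receives gradient.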
            x = stop_gradient(x)
            x = self.bias_add(x, self.bias)
            return x

    net = GradWrap(NetWithLossClass(Net()))

    predict = Tensor(np.ones([1, 64]))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    print("pynative run")
    out = net(predict, label)
    print("out:", out)


def test_stop_gradient_1():
    class Mul(nn.Cell):
        def __init__(self):
            super(Mul, self).__init__()

        @ms_function
        def construct(self, x, y):
            ret = x * y
            ret = stop_gradient(ret)
            return ret

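    # The network's only output is detached, so both inputs get zero gradient.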
    dx, dy = bprop(Mul(), Tensor(np.ones([2, 2]).astype(np.float32)),
                   Tensor(np.ones([2, 2]).astype(np.float32)), wrt=['inputs'])
    expect = np.zeros([2, 2])
    assert (dx.asnumpy() == expect).all()
    assert (dy.asnumpy() == expect).all()


def test_stop_gradient_2():
    class Mul(nn.Cell):
        def __init__(self):
            super(Mul, self).__init__()

        @ms_function
        def construct(self, x, y):
            c = x * y
            z = x * y
            return c, z

    class MulAdd(nn.Cell):
        def __init__(self):
            super(MulAdd, self).__init__()
            self.mul = Mul()

        @ms_function
        def construct(self, x, y):
            u = x + y
            v = x - y
            c, z = self.mul(u, v)
            c = stop_gradient(c)
            ret1 = c + x + y
            ret2 = z + y + y
            return ret1, ret2

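    # With c stopped, ret1 contributes d/dx = 1 and ret2 = (x + y)*(x - y) + 2*y
    # contributes d/dx = 2*x, so dx = 1 + 2*1 = 3 at x = y = 1.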
    dx = bprop(MulAdd(), Tensor(np.ones([2, 2]).astype(np.float32)),
               Tensor(np.ones([2, 2]).astype(np.float32)))
    expect = np.array([[3.0, 3.0], [3.0, 3.0]])
    assert (dx.asnumpy() == expect).all()


def test_stop_gradient_3():
    class TupleGetItem(nn.Cell):
        def __init__(self):
            super(TupleGetItem, self).__init__()

        @ms_function
        def construct(self, x1, x2, x3, x4, x5):
            z1 = x1 + x1
            z2 = x1 * x2
            t = (z1, z2, x3, x4, x5)
            z2 = t[1]
            z2 = stop_gradient(z2)
            return z1, z2, x3, x4, x5

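    # z2 is stopped, so only z1 = x1 + x1 flows back to x1, giving dx1 = 2.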
    dx = bprop(TupleGetItem(),
               Tensor(np.ones([2]).astype(np.float32)),
               Tensor(np.ones([2]).astype(np.float32)),
               Tensor(np.ones([2]).astype(np.float32)),
               Tensor(np.ones([2]).astype(np.float32)),
               Tensor(np.ones([2]).astype(np.float32)))
    expect = np.array([2.0, 2.0])
    assert (dx.asnumpy() == expect).all()


def test_stop_gradient_4():
    def stop_test(x):
        return stop_gradient(x)

    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,)


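# The stopped sum y contributes nothing; only the direct x term in ret does,
# so the gradient is 1.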
def test_stop_gradient_5():
    def stop_test(x):
        y = x + x
        y = stop_gradient(y)
        ret = x + y
        return ret

    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,)


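# The only output is stopped, so both inputs get zero gradient.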
def test_stop_gradient_6():
    def stop_test(x, y):
        ret = x * y
        ret = stop_gradient(ret)
        return ret

    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (0, 0)


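# Identity-like primitive with two outputs and a custom bprop that passes
# both incoming gradients straight through.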
class PrimWithMultiOutputs(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self):
        """init"""

    def __call__(self, x, y):
        """Implement by vm mode."""
        return x, y

    def infer_shape(self, x_shape, y_shape):
        return x_shape, y_shape

    def infer_dtype(self, x_type, y_type):
        return x_type, y_type

    def get_bprop(self):
        def bprop(x, y, out, dout):
            return (dout[0], dout[1])

        return bprop


def test_stop_gradient_7():
    class PrimWithMultiOutputs_(nn.Cell):
        def __init__(self):
            super(PrimWithMultiOutputs_, self).__init__()
            self.prim_with_multi_outputs = PrimWithMultiOutputs()

        @ms_function
        def construct(self, x1, x2):
            x1, x2 = self.prim_with_multi_outputs(x1, x2)
            x1 = stop_gradient(x1)
            return x1, x2

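    # Only the first output is stopped: dx is zero while dy = 1 flows through
    # the primitive's custom bprop.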
    dx, dy = bprop(PrimWithMultiOutputs_(), Tensor(np.ones([2]).astype(np.float32)),
                   Tensor(np.ones([2]).astype(np.float32)), wrt=['inputs'])
    expect_dx = np.zeros([2])
    expect_dy = np.ones([2])
    assert (dx.asnumpy() == expect_dx).all()
    assert (dy.asnumpy() == expect_dy).all()


def test_stop_gradient_8():
    class PrimWithMultiOutputs_(nn.Cell):
        def __init__(self):
            super(PrimWithMultiOutputs_, self).__init__()
            self.prim_with_multi_output = PrimWithMultiOutputs()

        @ms_function
        def construct(self, x1, x2):
            x1, x2 = stop_gradient(self.prim_with_multi_output(x1, x2))
            return x1, x2

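    # stop_gradient is applied to the whole output tuple, so both gradients are zero.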
    dx, dy = bprop(PrimWithMultiOutputs_(), Tensor(np.ones([2]).astype(np.float32)),
                   Tensor(np.ones([2]).astype(np.float32)), wrt=['inputs'])
    expect_dx = np.zeros([2])
    expect_dy = np.zeros([2])
    assert (dx.asnumpy() == expect_dx).all()
    assert (dy.asnumpy() == expect_dy).all()


def test_stop_gradient_9():
    class Mul(nn.Cell):
        def __init__(self):
            super(Mul, self).__init__()

        @ms_function
        def construct(self, x, y):
            c = x * y
            z = x * y
            return c, z

    class MulAdd(nn.Cell):
        def __init__(self):
            super(MulAdd, self).__init__()
            self.mul = Mul()

        @ms_function
        def construct(self, x, y):
            u = x + y
            v = x - y
            c, z = self.mul(u, v)
            c1 = stop_gradient(c)
            c2 = c
            ret1 = c1 + x + y + c2
            ret2 = z + y + y
            return ret1, ret2

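    # Stopping c1 does not detach the alias c2, so ret1 contributes
    # d/dx = 1 + 2*x and ret2 contributes d/dx = 2*x, giving dx = 5 at x = y = 1.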
    dx = bprop(MulAdd(), Tensor(np.ones([2, 2]).astype(np.float32)),
               Tensor(np.ones([2, 2]).astype(np.float32)))
    expect = np.array([[5.0, 5.0], [5.0, 5.0]])
    assert (dx.asnumpy() == expect).all()


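# Identity-like primitive that deliberately defines no bprop; differentiating
# through its outputs fails unless every path to them is stopped.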
class PrimWithNoBprop(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self):
        """init"""

    def __call__(self, x, y):
        """Implement by vm mode."""
        return x, y

    def infer_shape(self, x_shape, y_shape):
        return x_shape, y_shape

    def infer_dtype(self, x_type, y_type):
        return x_type, y_type


def test_stop_gradient_10():
    class PrimWithNoBprop_(nn.Cell):
        def __init__(self):
            super(PrimWithNoBprop_, self).__init__()
            self.prim_with_no_bprop = PrimWithNoBprop()

        @ms_function
        def construct(self, x, y):
            x = x * y
            x, y = self.prim_with_no_bprop(x, y)
            x = stop_gradient(x)
            y = stop_gradient(y)
            return x, y

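    # Both outputs are stopped, so the missing bprop is never invoked and dx = 0.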
    dx = bprop(PrimWithNoBprop_(), Tensor(np.ones([2]).astype(np.float32)),
               Tensor(np.ones([2]).astype(np.float32)))
    expect_dx = np.zeros([2])
    assert (dx.asnumpy() == expect_dx).all()


def test_stop_gradient_11():
    class PrimWithNoBprop_(nn.Cell):
        def __init__(self):
            super(PrimWithNoBprop_, self).__init__()
            self.prim_with_no_bprop = PrimWithNoBprop()

        @ms_function
        def construct(self, x, y):
            x, y = self.prim_with_no_bprop(x, y)
            x = stop_gradient(x)
            return x, y

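    # y is still differentiable, so autodiff needs the primitive's bprop and raises.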
    with pytest.raises(RuntimeError):
        bprop(PrimWithNoBprop_(), Tensor(np.ones([2]).astype(np.float32)),
              Tensor(np.ones([2]).astype(np.float32)))


def test_stop_print():
    class StopPrint(nn.Cell):
        def __init__(self):
            super(StopPrint, self).__init__()
            self.printm = P.Print()

        def construct(self, x, y):
            self.printm("StopPrint", x)
            self.printm(y)
            return x, y

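    # Smoke test: taking gradients through a network with Print side effects
    # should not raise.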
    C.grad_all(StopPrint())(Tensor(np.ones([2]).astype(np.float32)),
                            Tensor(np.ones([2]).astype(np.float32)))