You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_ms_function.py 2.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # Copyright 2020-2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore.nn as nn
  18. from mindspore.ops import composite as C
  19. from mindspore.nn import Momentum
  20. from mindspore import context, Tensor
  21. from mindspore.common.api import ms_function
  22. grad_all = C.GradOperation(get_all=True)
  23. class CellBprop(nn.Cell):
  24. def __init__(self):
  25. super(CellBprop, self).__init__()
  26. def construct(self, x, y):
  27. return 2 * x * x + y * y
  28. @ms_function
  29. def bprop(self, x, y, out, dout):
  30. return dout, 2 * y
  31. def test_cell_bprop_grad():
  32. input_x = Tensor(np.random.randn(2, 2).astype(np.float32))
  33. input_y = Tensor(np.random.randn(2, 2).astype(np.float32))
  34. context.set_context(mode=context.PYNATIVE_MODE)
  35. net = CellBprop()
  36. with pytest.raises(RuntimeError):
  37. grad_all(net)(input_x, input_y)
  38. class ConvNet(nn.Cell):
  39. def __init__(self):
  40. super(ConvNet, self).__init__()
  41. self.conv = nn.Conv2d(1, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")
  42. def construct(self, x):
  43. out = self.conv(x)
  44. return out
  45. class MomentumWithMsFunc(nn.Cell):
  46. def __init__(self, net):
  47. super(MomentumWithMsFunc, self).__init__()
  48. self.net = net
  49. self.optimizer = Momentum(filter(lambda x: x.requires_grad, self.net.get_parameters()), 0.1, 0.9)
  50. @ms_function
  51. def construct(self, grads):
  52. ret = self.optimizer(grads)
  53. return ret
  54. def test_ms_func_decorate_forward():
  55. context.set_context(mode=context.PYNATIVE_MODE)
  56. input_x = Tensor(np.random.randn(1, 1, 2, 2).astype(np.float32))
  57. net = ConvNet()
  58. grad_out = grad_all(net)(input_x)
  59. opt = MomentumWithMsFunc(net)
  60. opt(grad_out)