You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tuple_parameter.py 2.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import numpy as np
  2. import mindspore.nn as nn
  3. from mindspore import context, Tensor
  4. from mindspore.ops import operations as P
  5. from mindspore.ops import composite as C
  6. def setup_module(module):
  7. context.set_context(mode=context.PYNATIVE_MODE)
  8. class Block1(nn.Cell):
  9. """ Define Cell with tuple input as parameter."""
  10. def __init__(self):
  11. super(Block1, self).__init__()
  12. self.mul = P.Mul()
  13. def construct(self, tuple_xy):
  14. x, y = tuple_xy
  15. z = self.mul(x, y)
  16. return z
  17. class Block2(nn.Cell):
  18. """ definition with tuple in tuple output in Cell."""
  19. def __init__(self):
  20. super(Block2, self).__init__()
  21. self.mul = P.Mul()
  22. self.add = P.Add()
  23. def construct(self, x, y):
  24. z1 = self.mul(x, y)
  25. z2 = self.add(z1, x)
  26. z3 = self.add(z1, y)
  27. return (z1, (z2, z3))
  28. class Net1(nn.Cell):
  29. def __init__(self):
  30. super(Net1, self).__init__()
  31. self.block = Block1()
  32. def construct(self, x, y):
  33. res = self.block((x, y))
  34. return res
  35. class Net2(nn.Cell):
  36. def __init__(self):
  37. super(Net2, self).__init__()
  38. self.add = P.Add()
  39. self.block = Block2()
  40. def construct(self, x, y):
  41. z1, (z2, z3) = self.block(x, y)
  42. res = self.add(z1, z2)
  43. res = self.add(res, z3)
  44. return res
  45. def test_net():
  46. x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32) * 2)
  47. y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32) * 3)
  48. net1 = Net1()
  49. grad_op = C.GradOperation(get_all=True)
  50. output = grad_op(net1)(x, y)
  51. assert np.all(output[0].asnumpy() == y.asnumpy())
  52. assert np.all(output[1].asnumpy() == x.asnumpy())
  53. net2 = Net2()
  54. output = grad_op(net2)(x, y)
  55. expect_x = np.ones([1, 1, 3, 3]).astype(np.float32) * 10
  56. expect_y = np.ones([1, 1, 3, 3]).astype(np.float32) * 7
  57. assert np.all(output[0].asnumpy() == expect_x)
  58. assert np.all(output[1].asnumpy() == expect_y)