You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_tuple_parameter.py 2.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import numpy as np
  2. import mindspore.nn as nn
  3. from mindspore import context, Tensor
  4. from mindspore.ops import operations as P
  5. from mindspore.ops import composite as C
  6. def setup_module(module):
  7. context.set_context(mode=context.PYNATIVE_MODE)
  8. class Block1(nn.Cell):
  9. """ Define Cell with tuple input as paramter."""
  10. def __init__(self):
  11. super(Block1, self).__init__()
  12. self.mul = P.Mul()
  13. def construct(self, tuple_xy):
  14. x, y = tuple_xy
  15. z = self.mul(x, y)
  16. return z
  17. class Block2(nn.Cell):
  18. """ definition with tuple in tuple output in Cell."""
  19. def __init__(self):
  20. super(Block2, self).__init__()
  21. self.mul = P.Mul()
  22. self.add = P.TensorAdd()
  23. def construct(self, x, y):
  24. z1 = self.mul(x, y)
  25. z2 = self.add(z1, x)
  26. z3 = self.add(z1, y)
  27. return (z1, (z2, z3))
  28. class Net1(nn.Cell):
  29. def __init__(self):
  30. super(Net1, self).__init__()
  31. self.block = Block1()
  32. def construct(self, x, y):
  33. res = self.block((x, y))
  34. return res
  35. class Net2(nn.Cell):
  36. def __init__(self):
  37. super(Net2, self).__init__()
  38. self.add = P.TensorAdd()
  39. self.block = Block2()
  40. def construct(self, x, y):
  41. z1, (z2, z3) = self.block(x, y)
  42. res = self.add(z1, z2)
  43. res = self.add(res, z3)
  44. return res
  45. def test_net():
  46. context.set_context(save_graphs=True)
  47. x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32) * 2)
  48. y = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32) * 3)
  49. net1 = Net1()
  50. grad_op = C.GradOperation(get_all=True)
  51. output = grad_op(net1)(x, y)
  52. assert np.all(output[0].asnumpy() == y.asnumpy())
  53. assert np.all(output[1].asnumpy() == x.asnumpy())
  54. net2 = Net2()
  55. output = grad_op(net2)(x, y)
  56. expect_x = np.ones([1, 1, 3, 3]).astype(np.float32) * 10
  57. expect_y = np.ones([1, 1, 3, 3]).astype(np.float32) * 7
  58. assert np.all(output[0].asnumpy() == expect_x)
  59. assert np.all(output[1].asnumpy() == expect_y)