You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_composite.py 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import numpy as np
  2. import mindspore.context as context
  3. from mindspore import Tensor, Parameter
  4. from mindspore.nn import Cell, Composite
  5. from mindspore.ops import operations as P
  6. import mindspore.ops.composite as C
  7. import logging
  8. log = logging.getLogger("ME")
  9. log.setLevel(level=logging.DEBUG)
  10. context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target="Ascend")
  11. class Sigmoid(Composite):
  12. def __init__(self):
  13. super(Sigmoid, self).__init__()
  14. self.neg = P.Neg()
  15. self.exp = P.Exp()
  16. self.div = P.Div()
  17. def construct(self, x):
  18. neg_val = self.neg(x)
  19. exp_val = self.exp(neg_val)
  20. sigmoid = 1.0 / (1.0 + exp_val)
  21. return sigmoid
  22. class Net(Cell):
  23. def __init__(self):
  24. super(Net, self).__init__()
  25. self.sigmoid = Sigmoid()
  26. self.exp = P.Exp()
  27. def construct(self, x):
  28. return self.sigmoid(x) * self.exp(x)
  29. class NetComposite(Composite):
  30. def __init__(self):
  31. super(NetComposite, self).__init__()
  32. self.sigmoid = Sigmoid()
  33. self.exp = P.Exp()
  34. def construct(self, x):
  35. return self.sigmoid(x) * self.exp(x)
  36. class Net1(Cell):
  37. def __init__(self):
  38. super(Net1, self).__init__()
  39. self.exp = P.Exp()
  40. def construct(self, x):
  41. return self.exp(x)
  42. class NetComposite1(Composite):
  43. def __init__(self):
  44. super(NetComposite1, self).__init__()
  45. self.net = Net1()
  46. self.exp = P.Exp()
  47. def construct(self, x):
  48. return self.exp(x) * self.net(x)
  49. class Net_grad(Cell):
  50. def __init__(self):
  51. super(Net_grad, self).__init__()
  52. self.sigmoid = Sigmoid()
  53. self.exp = P.Exp()
  54. def construct(self, x, dout):
  55. dout = C.grad_with_sens(self.sigmoid)(x, dout)
  56. #out = self.sigmoid(x)
  57. return dout
  58. def vm_impl(x):
  59. return (1.0 / (1.0 + np.exp(-x))) * np.exp(x)
  60. def vm_sigmoid(x):
  61. return 1.0 / (1.0 + np.exp(-x))
  62. # composite not inline funcGraph
  63. def test_composite_sigmoid1():
  64. x = np.random.normal(0, 1, [2, 3]).astype(np.float32)
  65. net = Net()
  66. result = net(Tensor(x))
  67. vm_result = vm_impl(x)
  68. print("=======================================")
  69. print("x: {}".format(x))
  70. print("result: {}".format(result))
  71. print("vm_result: {}".format(vm_result))
  72. print("=======================================")
  73. # composite inline composite
  74. def test_composite_sigmoid2():
  75. x = Tensor(np.random.normal(0, 1, [2, 3]).astype(np.float32))
  76. net = NetComposite()
  77. result = net(x)
  78. print("=======================================")
  79. print(result)
  80. print("=======================================")
  81. # composite inline func
  82. def test_composite_sigmoid3():
  83. x = Tensor(np.random.normal(0, 1, [2, 3]).astype(np.float32))
  84. net = NetComposite1()
  85. result = net(x)
  86. print("=======================================")
  87. print(result)
  88. print("=======================================")
  89. def test_composite_sigmoid_grad():
  90. x = np.random.normal(0, 1, [2, 3]).astype(np.float32)
  91. dout = np.random.normal(0, 1, [2, 3]).astype(np.float32)
  92. net = Net_grad()
  93. result = net(Tensor(x), Tensor(dout))
  94. print("=======================================")
  95. print("x: {}".format(x))
  96. print("result: {}".format(result))
  97. print("=======================================")
  98. test_composite_sigmoid1()
  99. test_composite_sigmoid_grad()