You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_lamb.py 6.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. import logging
  2. import numpy as np
  3. import mindspore.context as context
  4. import mindspore.ops.composite as C
  5. from mindspore import Tensor, Parameter
  6. from mindspore.common.initializer import initializer
  7. from mindspore.nn import Cell
  8. from mindspore.ops import operations as P
  9. from mindspore.nn.composite_ops import LambUpdateWithLR, LambNextMV
  10. from mindspore.ops import functional as F
  11. from mindspore.common import dtype as mstype
  12. log = logging.getLogger("ME")
  13. log.setLevel(level=logging.DEBUG)
  14. context.set_context(mode=context.GRAPH_MODE, save_graphs=True, device_target="Ascend")
  15. class LambUpdateNet(Cell):
  16. def __init__(self,shape):
  17. super(LambUpdateNet, self).__init__()
  18. self.lamb_update = LambUpdateWithLR()
  19. self.x6 = Parameter(initializer('normal', shape), name='x6')
  20. def construct(self, x1, x2, x3, x4, x5, gy, se, my):
  21. return self.lamb_update(x1, x2, x3, x4, x5, self.x6, gy, se, my)
  22. class LambUpdateNetTbe(Cell):
  23. def __init__(self):
  24. super(LambUpdateNetTbe, self).__init__()
  25. self.mul = P.Mul()
  26. self.sqrt = P.Sqrt()
  27. self.rsqrt = P.Rsqrt()
  28. self.square = P.Square()
  29. self.cast = P.Cast()
  30. self.reshape = P.Reshape()
  31. self.shape = P.Shape()
  32. self.pow = P.Pow()
  33. self.select = P.Select()
  34. self.greater = P.Greater()
  35. self.fill = P.Fill()
  36. self.dtype = P.DType()
  37. def construct(self, x1, x2, x3, x4, x5, x6, gy, se, my):
  38. trust_ratio = self.select(
  39. self.greater(x2, gy),
  40. self.select(self.greater(x1, gy), x2 / x3, se),
  41. se
  42. )
  43. trust_ratio = C.clip_by_value(trust_ratio, gy, my)
  44. update_with_lr = self.mul(self.mul(trust_ratio, x4), x5)
  45. next_param = x6 - self.reshape(update_with_lr, self.shape(x6))
  46. return next_param
  47. class LambNextMVNet(Cell):
  48. def __init__(self, shape):
  49. super(LambNextMVNet, self).__init__()
  50. self.i2 = Parameter(initializer('normal', shape), name='i2')
  51. self.i5 = Parameter(initializer('normal', shape), name='i5')
  52. self.lamb_next = LambNextMV()
  53. def construct(self, i1, i3, i4, i6, i7, i8, i9, x0, x1, x2, x3):
  54. return self.lamb_next(i1, self.i2, i3, i4, self.i5, i6, i7, i8, i9, x0, x1, x2, x3)
  55. class LambNextMVNetTbe(Cell):
  56. def __init__(self):
  57. super(LambNextMVNetTbe, self).__init__()
  58. self.mul = P.Mul()
  59. self.sqrt = P.Sqrt()
  60. self.rsqrt = P.Rsqrt()
  61. self.square = P.Square()
  62. self.cast = P.Cast()
  63. self.reshape = P.Reshape
  64. self.pow = P.Pow()
  65. self.select = P.Select()
  66. def construct(self, i1, i2, i3, i4, i5, i6, i7, i8, i9, x0, x1, x2, x3):
  67. # x1: 1 - beta2 i1: g^2 x0: beta2
  68. # i2: v i9: 1 - beta1 i4: g
  69. # i8: beta1 i5: m i6: 1 - beta1^(gs + 1)
  70. # i3: 1 - beta2^(gs + 1) x3: eps
  71. # x2: weight_decay_tensor i7: param
  72. m_fp32 = self.cast(i5, mstype.float32)
  73. v_fp32 = self.cast(i2, mstype.float32)
  74. next_m = self.mul(i8, m_fp32) + self.mul(i9, i4)
  75. next_v = self.mul(x0, v_fp32) + self.mul(x1, i1)
  76. next_mm = next_m / i6
  77. next_vv = next_v / i3
  78. update = next_mm / (self.sqrt(next_vv) + x3)
  79. add3 = self.mul(next_mm, self.rsqrt(next_vv + x3)) + x2 * i7
  80. return add3, next_m, next_v, update
  81. def tensor_all(*args):
  82. res = [Tensor(a) for a in args]
  83. return res
  84. # Disabled: composite ops are not inlined into the FuncGraph, so these tests are commented out.
  85. # def test_composite_lamb_update_with_lr():
  86. # shape = [1, 16]
  87. # oshape = [1]
  88. # x1 = np.random.normal(0, 1, oshape).astype(np.float32)
  89. # x2 = np.random.normal(0, 1, oshape).astype(np.float32)
  90. # x3 = np.random.normal(0, 1, oshape).astype(np.float32)
  91. # x4 = np.random.normal(0, 1, oshape).astype(np.float32)
  92. # x5 = np.random.normal(0, 1, shape).astype(np.float32)
  93. # gy = np.random.normal(0, 1, oshape).astype(np.float32)
  94. # se = np.random.normal(0, 1, oshape).astype(np.float32)
  95. # my = np.random.normal(0, 1, oshape).astype(np.float32)
  96. # net = LambUpdateNet(shape)
  97. # net1 = LambUpdateNetTbe()
  98. # x6 = net.x6.data.asnumpy().copy()
  99. # tx1, tx2, tx3, tx4, tx5, tx6, tgy, tse, tmy = tensor_all(x1, x2, x3, x4, x5, x6, gy, se, my)
  100. # _ = net(tx1, tx2, tx3, tx4, tx5, tgy, tse, tmy)
  101. # tres = net1(tx1, tx2, tx3, tx4, tx5, tx6, tgy, tse, tmy)
  102. # ares = net.x6.data.asnumpy().copy()
  103. # print("=======================================")
  104. # print("x6 before:\n{}".format(x6))
  105. # print("white res:\n{}".format(ares)) # x6 is updated in place
  106. # print("tbe b res:\n{}".format(tres))
  107. # print("=======================================")
  108. # def test_composite_lamb_next_mv():
  109. # shape = [1, 16]
  110. # i1 = np.random.normal(0, 1, shape).astype(np.float32)
  111. # i3 = np.random.normal(0, 1, shape).astype(np.float32)
  112. # i4 = np.random.normal(0, 1, shape).astype(np.float32)
  113. # i6 = np.random.normal(0, 1, shape).astype(np.float32)
  114. # i7 = np.random.normal(0, 1, shape).astype(np.float32)
  115. # i8 = np.random.normal(0, 1, shape).astype(np.float32)
  116. # i9 = np.random.normal(0, 1, shape).astype(np.float32)
  117. # x0 = np.random.normal(0, 1, shape).astype(np.float32)
  118. # x1 = np.random.normal(0, 1, shape).astype(np.float32)
  119. # x2 = np.random.normal(0, 1, shape).astype(np.float32)
  120. # x3 = np.random.normal(0, 1, shape).astype(np.float32)
  121. # net = LambNextMVNet(shape)
  122. # net1 = LambNextMVNetTbe()
  123. # i2 = net.i2.data.asnumpy().copy()
  124. # i5 = net.i5.data.asnumpy().copy()
  125. # ti1, ti2, ti3, ti4, ti5, ti6, ti7, ti8, ti9, tx0, tx1, tx2, tx3 = \
  126. # tensor_all(i1, i2, i3, i4, i5, i6, i7, i8, i9, x0, x1, x2, x3)
  127. # wa3, wup = net(ti1, ti3, ti4, ti6, ti7, ti8, ti9, tx0, tx1, tx2, tx3)
  128. # ba3, ba0, ba1, bup = net1(ti1, ti2, ti3, ti4, ti5, ti6, ti7, ti8, ti9, tx0, tx1, tx2, tx3)
  129. # wi2 = net.i2.data.asnumpy().copy()
  130. # wi5 = net.i5.data.asnumpy().copy()
  131. # print("==========================================")
  132. # print("before: \ni2:\n{}\ni5:\n{}".format(i2, i5))
  133. # print("wa3:{}\nwi2:\n{}\nwi5:\n{}\nwup:\n{}".format(wa3, wi2, wi5, wup))
  134. # print("ba3:{}\nbi2:\n{}\nbi5:\n{}\nbup:\n{}".format(ba3, ba0, ba1, bup))