|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- import logging
- import numpy as np
- import mindspore.context as context
- import mindspore.ops.composite as C
- from mindspore import Tensor, Parameter
- from mindspore.common.initializer import initializer
- from mindspore.nn import Cell
- from mindspore.ops import operations as P
- from mindspore.nn.composite_ops import LambUpdateWithLR, LambNextMV
- from mindspore.ops import functional as F
- from mindspore.common import dtype as mstype
-
# Module-wide "ME" logger, turned up to DEBUG for these tests.
log = logging.getLogger("ME")
log.setLevel(logging.DEBUG)
# Every network below runs in graph mode on an Ascend device; save_graphs=True
# asks MindSpore to save the compiled graphs for inspection.
context.set_context(device_target="Ascend", mode=context.GRAPH_MODE, save_graphs=True)
-
class LambUpdateNet(Cell):
    """Run the composite ``LambUpdateWithLR`` op against a cell-owned parameter.

    The cell owns a ``Parameter`` named ``x6``, created with the ``'normal'``
    initializer for the given ``shape``. It is passed as the sixth operand of
    the composite op. All other operands are forwarded from ``construct``.
    """

    def __init__(self, shape):
        super().__init__()
        self.lamb_update = LambUpdateWithLR()
        self.x6 = Parameter(initializer('normal', shape), name='x6')

    def construct(self, x1, x2, x3, x4, x5, gy, se, my):
        # The owned parameter is spliced in between x5 and gy to match the
        # composite op's operand order.
        return self.lamb_update(x1, x2, x3, x4, x5, self.x6, gy, se, my)
-
class LambUpdateNetTbe(Cell):
    """Primitive-level reference for the ``LambUpdateWithLR`` computation.

    ``construct`` computes a clipped "trust ratio", scales it by ``x4`` and
    ``x5``, and subtracts the result (reshaped to ``x6``'s shape) from ``x6``.

    NOTE(review): by analogy with LAMB, ``x2`` looks like a weight norm and
    ``x1``/``x3`` like gradient norms. This is not established here; confirm.
    """

    def __init__(self):
        super().__init__()
        # Only mul, reshape, shape, select and greater are used by construct.
        # The remaining primitives are created but currently unused.
        self.mul = P.Mul()
        self.sqrt = P.Sqrt()
        self.rsqrt = P.Rsqrt()
        self.square = P.Square()
        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.shape = P.Shape()
        self.pow = P.Pow()
        self.select = P.Select()
        self.greater = P.Greater()
        self.fill = P.Fill()
        self.dtype = P.DType()

    def construct(self, x1, x2, x3, x4, x5, x6, gy, se, my):
        # Inner choice: x2 / x3 where x1 > gy, otherwise se.
        ratio_or_default = self.select(self.greater(x1, gy), x2 / x3, se)
        # Outer choice: keep that value only where x2 > gy, otherwise se.
        trust_ratio = self.select(self.greater(x2, gy), ratio_or_default, se)
        # Clamp into [gy, my].
        trust_ratio = C.clip_by_value(trust_ratio, gy, my)
        scaled_update = self.mul(self.mul(trust_ratio, x4), x5)
        return x6 - self.reshape(scaled_update, self.shape(x6))
-
class LambNextMVNet(Cell):
    """Run the composite ``LambNextMV`` op with two cell-owned parameters.

    ``i2`` and ``i5`` are ``Parameter`` objects created with the ``'normal'``
    initializer for ``shape``. They fill the second and fifth operand slots.
    The other eleven operands are forwarded from ``construct``.
    """

    def __init__(self, shape):
        super().__init__()
        self.i2 = Parameter(initializer('normal', shape), name='i2')
        self.i5 = Parameter(initializer('normal', shape), name='i5')
        self.lamb_next = LambNextMV()

    def construct(self, i1, i3, i4, i6, i7, i8, i9, x0, x1, x2, x3):
        # Owned parameters are interleaved at positions 2 and 5 to match the
        # composite op's operand order.
        return self.lamb_next(i1, self.i2, i3, i4, self.i5,
                              i6, i7, i8, i9, x0, x1, x2, x3)
-
class LambNextMVNetTbe(Cell):
    """Primitive-level reference for the LAMB "next m / next v" computation.

    Operand roles:

    * ``i4``: gradient ``g``; ``i1``: ``g^2``
    * ``i5``: first moment ``m``; ``i2``: second moment ``v``
    * ``i8``: ``beta1``; ``i9``: ``1 - beta1``
    * ``x0``: ``beta2``; ``x1``: ``1 - beta2``
    * ``i6``: ``1 - beta1^(gs + 1)``; ``i3``: ``1 - beta2^(gs + 1)``
    * ``x3``: ``eps``; ``x2``: weight-decay tensor; ``i7``: the parameter

    ``construct`` returns ``(add3, next_m, next_v, update)`` where::

        next_m = beta1 * m + (1 - beta1) * g
        next_v = beta2 * v + (1 - beta2) * g^2
        next_mm = next_m / i6
        next_vv = next_v / i3
        update = next_mm / (sqrt(next_vv) + eps)
        add3 = next_mm * rsqrt(next_vv + eps) + x2 * i7
    """

    def __init__(self):
        super(LambNextMVNetTbe, self).__init__()
        self.mul = P.Mul()
        self.sqrt = P.Sqrt()
        self.rsqrt = P.Rsqrt()
        self.square = P.Square()
        self.cast = P.Cast()
        # Instantiate the primitive. ``P.Reshape`` alone is the operator
        # class, and calling it with (tensor, shape) would not reshape.
        self.reshape = P.Reshape()
        self.pow = P.Pow()
        self.select = P.Select()

    def construct(self, i1, i2, i3, i4, i5, i6, i7, i8, i9, x0, x1, x2, x3):
        # x1: 1 - beta2   i1: g^2   x0: beta2
        # i2: v           i9: 1 - beta1   i4: g
        # i8: beta1       i5: m     i6: 1 - beta1^(gs + 1)
        # i3: 1 - beta2^(gs + 1)    x3: eps
        # x2: weight_decay_tensor   i7: param
        # Moments are promoted to float32 before being updated.
        m_fp32 = self.cast(i5, mstype.float32)
        v_fp32 = self.cast(i2, mstype.float32)
        next_m = self.mul(i8, m_fp32) + self.mul(i9, i4)
        next_v = self.mul(x0, v_fp32) + self.mul(x1, i1)
        # Bias-corrected moments.
        next_mm = next_m / i6
        next_vv = next_v / i3
        update = next_mm / (self.sqrt(next_vv) + x3)
        # Note: eps is added inside the rsqrt here, unlike in `update` above.
        add3 = self.mul(next_mm, self.rsqrt(next_vv + x3)) + x2 * i7
        return add3, next_m, next_v, update
-
-
def tensor_all(*args):
    """Wrap every positional argument in a ``Tensor``.

    Returns:
        A list with one ``Tensor`` per argument, in argument order.
    """
    return list(map(Tensor, args))
-
- # Composite ops are not inlined into the funcGraph, so the tests below are commented out.
- # def test_composite_lamb_update_with_lr():
- # shape = [1, 16]
- # oshape = [1]
- # x1 = np.random.normal(0, 1, oshape).astype(np.float32)
- # x2 = np.random.normal(0, 1, oshape).astype(np.float32)
- # x3 = np.random.normal(0, 1, oshape).astype(np.float32)
- # x4 = np.random.normal(0, 1, oshape).astype(np.float32)
- # x5 = np.random.normal(0, 1, shape).astype(np.float32)
- # gy = np.random.normal(0, 1, oshape).astype(np.float32)
- # se = np.random.normal(0, 1, oshape).astype(np.float32)
- # my = np.random.normal(0, 1, oshape).astype(np.float32)
-
- # net = LambUpdateNet(shape)
- # net1 = LambUpdateNetTbe()
-
- # x6 = net.x6.data.asnumpy().copy()
-
- # tx1, tx2, tx3, tx4, tx5, tx6, tgy, tse, tmy = tensor_all(x1, x2, x3, x4, x5, x6, gy, se, my)
-
- # _ = net(tx1, tx2, tx3, tx4, tx5, tgy, tse, tmy)
- # tres = net1(tx1, tx2, tx3, tx4, tx5, tx6, tgy, tse, tmy)
-
- # ares = net.x6.data.asnumpy().copy()
-
- # print("=======================================")
- # print("x6 before:\n{}".format(x6))
- # print("white res:\n{}".format(ares)) # x6 is updated in place
- # print("tbe b res:\n{}".format(tres))
- # print("=======================================")
-
-
- # def test_composite_lamb_next_mv():
- # shape = [1, 16]
- # i1 = np.random.normal(0, 1, shape).astype(np.float32)
- # i3 = np.random.normal(0, 1, shape).astype(np.float32)
- # i4 = np.random.normal(0, 1, shape).astype(np.float32)
- # i6 = np.random.normal(0, 1, shape).astype(np.float32)
- # i7 = np.random.normal(0, 1, shape).astype(np.float32)
- # i8 = np.random.normal(0, 1, shape).astype(np.float32)
- # i9 = np.random.normal(0, 1, shape).astype(np.float32)
- # x0 = np.random.normal(0, 1, shape).astype(np.float32)
- # x1 = np.random.normal(0, 1, shape).astype(np.float32)
- # x2 = np.random.normal(0, 1, shape).astype(np.float32)
- # x3 = np.random.normal(0, 1, shape).astype(np.float32)
-
- # net = LambNextMVNet(shape)
- # net1 = LambNextMVNetTbe()
-
- # i2 = net.i2.data.asnumpy().copy()
- # i5 = net.i5.data.asnumpy().copy()
-
- # ti1, ti2, ti3, ti4, ti5, ti6, ti7, ti8, ti9, tx0, tx1, tx2, tx3 = \
- # tensor_all(i1, i2, i3, i4, i5, i6, i7, i8, i9, x0, x1, x2, x3)
-
- # wa3, wup = net(ti1, ti3, ti4, ti6, ti7, ti8, ti9, tx0, tx1, tx2, tx3)
- # ba3, ba0, ba1, bup = net1(ti1, ti2, ti3, ti4, ti5, ti6, ti7, ti8, ti9, tx0, tx1, tx2, tx3)
- # wi2 = net.i2.data.asnumpy().copy()
- # wi5 = net.i5.data.asnumpy().copy()
-
- # print("==========================================")
- # print("before: \ni2:\n{}\ni5:\n{}".format(i2, i5))
- # print("wa3:{}\nwi2:\n{}\nwi5:\n{}\nwup:\n{}".format(wa3, wi2, wi5, wup))
- # print("ba3:{}\nbi2:\n{}\nbi5:\n{}\nbup:\n{}".format(ba3, ba1, ba0, bup))
|