Browse Source

add primitive operator to test_lamb

tags/v0.6.0-beta
duxiutao 5 years ago
parent
commit
793737ab62
1 changed files with 4 additions and 2 deletions
  1. +4
    -2
      tests/st/ops/graph_kernel/test_lamb.py

+ 4
- 2
tests/st/ops/graph_kernel/test_lamb.py View File

@@ -33,7 +33,8 @@ class LambNet(Cell):


def construct(self, i1, i3, i4, i6, i7, i8, i9, ix0, ix1, ix2, ix3, def construct(self, i1, i3, i4, i6, i7, i8, i9, ix0, ix1, ix2, ix3,
x1, x2, x3, x4, x5, gy, se, my): x1, x2, x3, x4, x5, gy, se, my):
return self.lamb_next(i1, self.i2, i3, i4, self.i5, i6, i7, i8, i9, ix0,
i1_ = i1 + i3
return self.lamb_next(i1_, self.i2, i3, i4, self.i5, i6, i7, i8, i9, ix0,
ix1, ix2, ix3), \ ix1, ix2, ix3), \
self.lamb_update(x1, x2, x3, x4, x5, self.x6, gy, se, my) self.lamb_update(x1, x2, x3, x4, x5, self.x6, gy, se, my)


@@ -113,7 +114,8 @@ def test_graph_kernel_lamb():


context.set_context(enable_graph_kernel=False) context.set_context(enable_graph_kernel=False)


a3, a0, a1, up = LambNextMVNumpy(i1, i2, i3, i4, i5, i6, i7, i8, i9, ix0,
i1_ = i1 + i3
a3, a0, a1, up = LambNextMVNumpy(i1_, i2, i3, i4, i5, i6, i7, i8, i9, ix0,
ix1, ix2, ix3) ix1, ix2, ix3)


np_res = LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my) np_res = LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my)


Loading…
Cancel
Save