| @@ -33,7 +33,8 @@ class LambNet(Cell): | |||||
| def construct(self, i1, i3, i4, i6, i7, i8, i9, ix0, ix1, ix2, ix3, | def construct(self, i1, i3, i4, i6, i7, i8, i9, ix0, ix1, ix2, ix3, | ||||
| x1, x2, x3, x4, x5, gy, se, my): | x1, x2, x3, x4, x5, gy, se, my): | ||||
| return self.lamb_next(i1, self.i2, i3, i4, self.i5, i6, i7, i8, i9, ix0, | |||||
| i1_ = i1 + i3 | |||||
| return self.lamb_next(i1_, self.i2, i3, i4, self.i5, i6, i7, i8, i9, ix0, | |||||
| ix1, ix2, ix3), \ | ix1, ix2, ix3), \ | ||||
| self.lamb_update(x1, x2, x3, x4, x5, self.x6, gy, se, my) | self.lamb_update(x1, x2, x3, x4, x5, self.x6, gy, se, my) | ||||
| @@ -113,7 +114,8 @@ def test_graph_kernel_lamb(): | |||||
| context.set_context(enable_graph_kernel=False) | context.set_context(enable_graph_kernel=False) | ||||
| a3, a0, a1, up = LambNextMVNumpy(i1, i2, i3, i4, i5, i6, i7, i8, i9, ix0, | |||||
| i1_ = i1 + i3 | |||||
| a3, a0, a1, up = LambNextMVNumpy(i1_, i2, i3, i4, i5, i6, i7, i8, i9, ix0, | |||||
| ix1, ix2, ix3) | ix1, ix2, ix3) | ||||
| np_res = LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my) | np_res = LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my) | ||||