|
|
|
@@ -51,13 +51,13 @@ def test_two_matmul_batchnorm_ex(): |
|
|
|
class Net(nn.Cell): |
|
|
|
def __init__(self, strategy1, strategy2):
    """Build the net: two sharded BatchMatMuls around a FusedBatchNormEx.

    Args:
        strategy1: shard strategy for the first BatchMatMul — a pair of
            4-D layouts, e.g. ((1, 1, 4, 2), (1, 1, 2, 1)).
        strategy2: shard strategy for the second BatchMatMul.
    """
    super().__init__()
    # NOTE(review): the original (diff residue) assigned self.matmul1 and
    # self.matmul2 twice — a dead P.MatMul().shard(...) immediately
    # overwritten by P.BatchMatMul().shard(...). Only the final BatchMatMul
    # assignments ever took effect, so the dead constructions are removed.
    self.matmul1 = P.BatchMatMul().shard(strategy1)
    self.norm = P.FusedBatchNormEx()
    # BatchNorm affine/statistics parameters over a 64-channel axis
    # (matches the 64-wide last dim produced by matmul1 — TODO confirm
    # against construct, which is truncated in this view).
    self.gamma = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="gamma")
    self.beta = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="beta")
    self.mean = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="mean")
    self.var = Parameter(Tensor(np.ones([64]), dtype=ms.float32), name="var")
    self.matmul2 = P.BatchMatMul().shard(strategy2)
|
|
|
|
|
|
|
def construct(self, x, y, b): |
|
|
|
out = self.matmul1(x, y) |
|
|
|
@@ -66,12 +66,12 @@ def test_two_matmul_batchnorm_ex(): |
|
|
|
return out |
|
|
|
|
|
|
|
# Configure 8-device auto-parallel, build the net, and compile it with
# concrete 4-D inputs so strategy propagation/validation runs.
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8)

# NOTE(review): the original (diff residue) assigned strategy1/strategy2
# and x/y/b twice each — dead 2-D MatMul-shaped values immediately
# overwritten by the 4-D BatchMatMul-shaped values below. Only the final
# assignments ever took effect; the dead ones are removed.
# 4-D shard layouts: (batch, batch, row, col) for each BatchMatMul operand.
strategy1 = ((1, 1, 4, 2), (1, 1, 2, 1))
strategy2 = ((1, 1, 1, 8), (1, 1, 8, 1))

net = GradWrap(NetWithLoss(Net(strategy1, strategy2)))
net.set_auto_parallel()

# Inputs shaped for BatchMatMul: two leading batch dims (64, 64), then the
# 2-D matmul operands (128x32) @ (32x64); b is presumably the extra input
# NetWithLoss/construct consumes — verify against the truncated construct.
x = Tensor(np.ones([64, 64, 128, 32]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64, 64, 64]), dtype=ms.float32)

net.set_train()
_executor.compile(net, x, y, b)