|
|
|
@@ -45,7 +45,6 @@ class GradWrap(nn.Cell): |
|
|
|
return C.grad_all(self.network)(x) |
|
|
|
|
|
|
|
|
|
|
|
# core dump, step_auto_parallel should SetInputs for transpose axis |
|
|
|
def test_reshape_matmul(): |
|
|
|
class Net(nn.Cell): |
|
|
|
def __init__(self): |
|
|
|
@@ -68,6 +67,28 @@ def test_reshape_matmul(): |
|
|
|
net.set_auto_parallel() |
|
|
|
_executor.compile(net, x) |
|
|
|
|
|
|
|
def test_reshape_reshape(): |
|
|
|
class Net(nn.Cell): |
|
|
|
def __init__(self): |
|
|
|
super().__init__() |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.relu = P.ReLU() |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
x = self.relu(x) |
|
|
|
out = self.reshape(x, (64, 28)) |
|
|
|
out = self.reshape(out, (64, 28, 1)) |
|
|
|
return out |
|
|
|
|
|
|
|
size = 8 |
|
|
|
context.set_auto_parallel_context(device_num=size, global_rank=0) |
|
|
|
x = Tensor(np.ones([8 * size, 28, 1, 1]), dtype=ms.float32) |
|
|
|
|
|
|
|
net = GradWrap(NetWithLoss(Net())) |
|
|
|
context.set_auto_parallel_context(parallel_mode="auto_parallel") |
|
|
|
net.set_auto_parallel() |
|
|
|
_executor.compile(net, x) |
|
|
|
|
|
|
|
|
|
|
|
def test_reshape_auto_1(): |
|
|
|
class Net(nn.Cell): |
|
|
|
|