|
|
|
@@ -97,6 +97,24 @@ class ReshapeMulNet(nn.Cell): |
|
|
|
out = self.mul(weight, self.mul_weight) |
|
|
|
return out |
|
|
|
|
|
|
|
class ParallelMulNet(nn.Cell):
    """Flatten -> Dense -> element-wise square network.

    The dense layer's weight and bias are both initialised to the constant
    0.01 (float32).

    Args:
        dense_in_channel: Number of input features of the dense layer.
        dense_out_channel: Number of output features of the dense layer.
    """

    def __init__(self, dense_in_channel=2048, dense_out_channel=250):
        super().__init__()
        fill_value = 0.01
        weight_shape = (dense_out_channel, dense_in_channel)
        bias_shape = (dense_out_channel,)
        weight_init = np.full(weight_shape, fill_value, dtype=np.float32)
        bias_init = np.full(bias_shape, fill_value, dtype=np.float32)

        self.flat = nn.Flatten()
        self.dense = nn.Dense(
            in_channels=dense_in_channel,
            out_channels=dense_out_channel,
            weight_init=Tensor(weight_init),
            bias_init=Tensor(bias_init),
            has_bias=True,
        )
        self.mul = P.Mul()

    def construct(self, inputs):
        # Flatten, project through the dense layer, then multiply the
        # projection by itself.
        flattened = self.flat(inputs)
        projected = self.dense(flattened)
        squared = self.mul(projected, projected)
        return squared
|
|
|
|
|
|
|
def compile_graph(x, net): |
|
|
|
net.set_auto_parallel() |
|
|
|
net.set_train(False) |
|
|
|
@@ -104,6 +122,13 @@ def compile_graph(x, net): |
|
|
|
strategies = _executor._get_shard_strategy(net) |
|
|
|
return strategies |
|
|
|
|
|
|
|
def compile_graph_two_input(x, y, net):
    """Compile a two-input network in auto-parallel mode.

    The network is switched to auto-parallel and to non-training mode before
    compilation.

    Args:
        x: First input passed to the compiler.
        y: Second input passed to the compiler.
        net: Network to compile.

    Returns:
        The shard strategies reported by the executor for ``net``.
    """
    net.set_auto_parallel()
    net.set_train(False)
    _executor.compile(net, x, y, auto_parallel_mode=True)
    return _executor._get_shard_strategy(net)
|
|
|
|
|
|
|
|
|
|
|
def test_dense_relu_semi_auto(): |
|
|
|
context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False) |
|
|
|
@@ -250,3 +275,33 @@ def test_reshape_mul_auto(): |
|
|
|
for (k, v) in strategies.items(): |
|
|
|
if re.search('VirtualOutput-op', k) is not None: |
|
|
|
assert v[0][0] == 1 |
|
|
|
|
|
|
|
def test_scalar_output_semi_auto():
    """Check the VirtualOutput strategy under semi_auto_parallel on 8 devices.

    Wraps ParallelMulNet with a mean-reduced softmax cross-entropy loss in an
    eval cell, compiles it, and expects exactly one 'VirtualOutput-op'
    strategy entry whose ``[0][0]`` element equals 8.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
    eval_net = nn.WithEvalCell(
        ParallelMulNet(),
        nn.SoftmaxCrossEntropyWithLogits(reduction='mean'),
    )

    data_np = np.ones([4096, 1, 2, 1024]).astype(np.float32) * 0.01
    label_np = np.ones([4096, 250]).astype(np.float32) * 0.01
    x = Tensor(data_np)
    label = Tensor(label_np)

    strategies = compile_graph_two_input(x, label, eval_net)

    matched = 0
    for op_name, strategy in strategies.items():
        if re.search('VirtualOutput-op', op_name) is None:
            continue
        assert strategy[0][0] == 8
        matched += 1
    assert matched == 1
|
|
|
|
|
|
|
def test_scalar_output_auto():
    """Check the VirtualOutput strategy under auto_parallel on 8 devices.

    Wraps ParallelMulNet with a mean-reduced softmax cross-entropy loss in an
    eval cell, compiles it, and expects exactly one 'VirtualOutput-op'
    strategy entry whose ``[0][0]`` element equals 8.
    """
    context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
    eval_net = nn.WithEvalCell(
        ParallelMulNet(),
        nn.SoftmaxCrossEntropyWithLogits(reduction='mean'),
    )

    data_np = np.ones([4096, 1, 2, 1024]).astype(np.float32) * 0.01
    label_np = np.ones([4096, 250]).astype(np.float32) * 0.01
    x = Tensor(data_np)
    label = Tensor(label_np)

    strategies = compile_graph_two_input(x, label, eval_net)

    matched = 0
    for op_name, strategy in strategies.items():
        if re.search('VirtualOutput-op', op_name) is None:
            continue
        assert strategy[0][0] == 8
        matched += 1
    assert matched == 1