|
|
|
@@ -48,8 +48,8 @@ def test_get_parameter_layout(): |
|
|
|
net.set_auto_parallel() |
|
|
|
exe = me._executor |
|
|
|
exe.compile(net, x, auto_parallel_mode=True) |
|
|
|
x_layout = ([2, 4], [1, -1]) # device_arrangement = [2, 4], tensor_map = [1, -1] |
|
|
|
weight_layout = ([2, 4], [0, -1]) # device_arrangement = [2, 4], tensor_map = [0, -1] |
|
|
|
x_layout = [[2, 4], [1, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [1, -1] |
|
|
|
weight_layout = [[2, 4], [0, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [0, -1] |
|
|
|
expect_dict = {'x': x_layout, 'w1': weight_layout} |
|
|
|
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut |
|
|
|
assert (net.parameter_layout_dict == expect_dict) |
|
|
|
|