|
|
|
@@ -80,9 +80,9 @@ def test_double_star_graph(): |
|
|
|
|
|
|
|
_executor.compile(net, x, y, z, w, phase='train') |
|
|
|
strategies = _executor._get_strategy(net) |
|
|
|
expected_strategies = {'Default/network-Net/Cast-op1': [[8, 1]], |
|
|
|
'Default/network-Net/Cast-op3': [[1, 8]], |
|
|
|
'Default/network-Net/MatMul-op2': [[8, 1], [1, 1]], |
|
|
|
expected_strategies = {'Default/network-Net/Cast-op0': [[8, 1]], |
|
|
|
'Default/network-Net/Cast-op1': [[1, 8]], |
|
|
|
'Default/network-Net/MatMul-op3': [[8, 1], [1, 1]], |
|
|
|
'Default/network-Net/MatMul-op4': [[1, 1], [1, 8]], |
|
|
|
'Default/network-Net/MatMul-op0': [[1, 8], [8, 1]]} |
|
|
|
'Default/network-Net/MatMul-op2': [[1, 8], [8, 1]]} |
|
|
|
assert strategies == expected_strategies |