Browse Source

refining strategy-checking for resnet50

tags/v0.3.0-alpha
Xiaoda Zhang chang zherui 5 years ago
parent
commit
4520469e16
1 changed files with 6 additions and 6 deletions
  1. +6
    -6
      tests/ut/python/parallel/test_auto_parallel_resnet.py

+ 6
- 6
tests/ut/python/parallel/test_auto_parallel_resnet.py View File

@@ -295,11 +295,11 @@ def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): #1048576
model.train(5, dataset, dataset_sink_mode=False) model.train(5, dataset, dataset_sink_mode=False)
strategies = _executor._get_strategy(model._train_network) strategies = _executor._get_strategy(model._train_network)
for (k, v) in strategies.items(): for (k, v) in strategies.items():
if re.match(k, 'Conv2D-op') is not None:
if re.search('Conv2D-op', k) is not None:
assert v[0][0] == dev_num assert v[0][0] == dev_num
elif re.match(k, 'MatMul-op') is not None:
elif re.search('MatMul-op', k) is not None:
assert v == [[dev_num, 1], [1, 1]] assert v == [[dev_num, 1], [1, 1]]
elif re.match(k, 'ReduceSum-op') is not None:
elif re.search('ReduceSum-op', k) is not None:
assert v == [[dev_num, 1]] assert v == [[dev_num, 1]]


allreduce_fusion_dict = _executor._get_allreduce_fusion(model._train_network) allreduce_fusion_dict = _executor._get_allreduce_fusion(model._train_network)
@@ -490,9 +490,9 @@ def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536): #1048576
model.train(5, dataset, dataset_sink_mode=False) model.train(5, dataset, dataset_sink_mode=False)
strategies = _executor._get_strategy(model._train_network) strategies = _executor._get_strategy(model._train_network)
for (k, v) in strategies.items(): for (k, v) in strategies.items():
if re.match(k, 'Conv2D-op') is not None:
if re.search('Conv2D-op', k ) is not None:
assert v[0][0] == dev_num assert v[0][0] == dev_num
elif re.match(k, 'MatMul-op') is not None:
elif re.search('MatMul-op', k) is not None:
assert v == [[1, 1], [dev_num, 1]] assert v == [[1, 1], [dev_num, 1]]
elif re.match(k, 'ReduceSum-op') is not None:
elif re.search('ReduceSum-op', k) is not None:
assert v == [[1, dev_num]] assert v == [[1, dev_num]]

Loading…
Cancel
Save