| @@ -295,11 +295,11 @@ def test_train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): #1048576 | |||||
| model.train(5, dataset, dataset_sink_mode=False) | model.train(5, dataset, dataset_sink_mode=False) | ||||
| strategies = _executor._get_strategy(model._train_network) | strategies = _executor._get_strategy(model._train_network) | ||||
| for (k, v) in strategies.items(): | for (k, v) in strategies.items(): | ||||
| if re.match(k, 'Conv2D-op') is not None: | |||||
| if re.search('Conv2D-op', k) is not None: | |||||
| assert v[0][0] == dev_num | assert v[0][0] == dev_num | ||||
| elif re.match(k, 'MatMul-op') is not None: | |||||
| elif re.search('MatMul-op', k) is not None: | |||||
| assert v == [[dev_num, 1], [1, 1]] | assert v == [[dev_num, 1], [1, 1]] | ||||
| elif re.match(k, 'ReduceSum-op') is not None: | |||||
| elif re.search('ReduceSum-op', k) is not None: | |||||
| assert v == [[dev_num, 1]] | assert v == [[dev_num, 1]] | ||||
| allreduce_fusion_dict = _executor._get_allreduce_fusion(model._train_network) | allreduce_fusion_dict = _executor._get_allreduce_fusion(model._train_network) | ||||
| @@ -490,9 +490,9 @@ def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536): #1048576 | |||||
| model.train(5, dataset, dataset_sink_mode=False) | model.train(5, dataset, dataset_sink_mode=False) | ||||
| strategies = _executor._get_strategy(model._train_network) | strategies = _executor._get_strategy(model._train_network) | ||||
| for (k, v) in strategies.items(): | for (k, v) in strategies.items(): | ||||
| if re.match(k, 'Conv2D-op') is not None: | |||||
| if re.search('Conv2D-op', k ) is not None: | |||||
| assert v[0][0] == dev_num | assert v[0][0] == dev_num | ||||
| elif re.match(k, 'MatMul-op') is not None: | |||||
| elif re.search('MatMul-op', k) is not None: | |||||
| assert v == [[1, 1], [dev_num, 1]] | assert v == [[1, 1], [dev_num, 1]] | ||||
| elif re.match(k, 'ReduceSum-op') is not None: | |||||
| elif re.search('ReduceSum-op', k) is not None: | |||||
| assert v == [[1, dev_num]] | assert v == [[1, dev_num]] | ||||