| @@ -1,3 +1,5 @@ | |||||
| import os | |||||
| import pytest | import pytest | ||||
| from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver | from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver | ||||
| @@ -40,9 +42,14 @@ def test_get_fleet(device): | |||||
| """ | """ | ||||
| 测试 fleet 多卡的初始化情况 | 测试 fleet 多卡的初始化情况 | ||||
| """ | """ | ||||
| flag = False | |||||
| if "USER_CUDA_VISIBLE_DEVICES" not in os.environ: | |||||
| os.environ["USER_CUDA_VISIBLE_DEVICES"] = "0,1,2,3" | |||||
| flag = True | |||||
| model = PaddleNormalModel_Classification_1(20, 10) | model = PaddleNormalModel_Classification_1(20, 10) | ||||
| driver = initialize_paddle_driver("paddle", device, model) | driver = initialize_paddle_driver("paddle", device, model) | ||||
| if flag: | |||||
| del os.environ["USER_CUDA_VISIBLE_DEVICES"] | |||||
| assert isinstance(driver, PaddleFleetDriver) | assert isinstance(driver, PaddleFleetDriver) | ||||