diff --git a/main.py b/main.py index 9ba1bf0..81bdc16 100644 --- a/main.py +++ b/main.py @@ -57,7 +57,7 @@ def main(stage, gpus = 1 if torch.cuda.is_available() and gpus is None and tpu_cores is None else None # 定义不常改动的通用参数 # TODO 获得最优的batch size - num_workers = cpu_count() + num_workers = min([cpu_count(), 8]) # 获得非通用参数 config = {'dim_in': 32, } for kth_fold in range(kth_fold_start, k_fold):