From b854ffd282d06da0ed77ad6577db13d4c4acbaef Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Mon, 20 Mar 2023 22:26:41 +0800 Subject: [PATCH] fix bug in BasicModel test --- .github/workflows/build-and-test.yaml | 6 +++++- abl/models/basic_model.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-and-test.yaml b/.github/workflows/build-and-test.yaml index 45b2bd2..0e38b25 100644 --- a/.github/workflows/build-and-test.yaml +++ b/.github/workflows/build-and-test.yaml @@ -11,7 +11,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, windows-latest] + os: ubuntu-latest python-version: [3.8, 3.9] steps: - uses: actions/checkout@v2 @@ -25,6 +25,10 @@ jobs: run: | python -m pip install --upgrade pip pip install -r ./requirements.txt + - uses: Jimver/cuda-toolkit@v0.2.10 + id: cuda-toolkit + with: + cuda: '12.1.0' - name: Run tests run: | pytest --cov-config=.coveragerc --cov-report=xml --cov=abl ./tests diff --git a/abl/models/basic_model.py b/abl/models/basic_model.py index 8c137e2..9050365 100644 --- a/abl/models/basic_model.py +++ b/abl/models/basic_model.py @@ -251,7 +251,7 @@ class BasicModel: def save(self, epoch_id, save_dir): recorder = self.recorder if not os.path.exists(save_dir): - os.mkdir(save_dir) + os.makedirs(save_dir) recorder.print("Saving model and opter") save_path = os.path.join(save_dir, str(epoch_id) + "_net.pth") torch.save(self.model.state_dict(), save_path)