Browse Source

fix bug in BasicModel test

pull/3/head
Gao Enhao 2 years ago
parent
commit
b854ffd282
2 changed files with 6 additions and 2 deletions
  1. +5
    -1
      .github/workflows/build-and-test.yaml
  2. +1
    -1
      abl/models/basic_model.py

+ 5
- 1
.github/workflows/build-and-test.yaml View File

@@ -11,7 +11,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, windows-latest]
os: ubuntu-latest
python-version: [3.8, 3.9]
steps:
- uses: actions/checkout@v2
@@ -25,6 +25,10 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r ./requirements.txt
- uses: Jimver/cuda-toolkit@v0.2.10
id: cuda-toolkit
with:
cuda: '12.1.0'
- name: Run tests
run: |
pytest --cov-config=.coveragerc --cov-report=xml --cov=abl ./tests


+ 1
- 1
abl/models/basic_model.py View File

@@ -251,7 +251,7 @@ class BasicModel:
def save(self, epoch_id, save_dir):
recorder = self.recorder
if not os.path.exists(save_dir):
os.mkdir(save_dir)
os.makedirs(save_dir)
recorder.print("Saving model and opter")
save_path = os.path.join(save_dir, str(epoch_id) + "_net.pth")
torch.save(self.model.state_dict(), save_path)


Loading…
Cancel
Save