Browse Source

[MNI] Minor issues in code

pull/5/head
Tony-HYX 2 years ago
parent
commit
7ceb3a3da1
3 changed files with 7 additions and 7 deletions
  1. +6
    -5
      abl/learning/basic_nn.py
  2. +0
    -2
      abl/reasoning/kb.py
  3. +1
    -0
      pyproject.toml

+ 6
- 5
abl/learning/basic_nn.py View File

@@ -307,8 +307,9 @@ class BasicNN:
data_loader = DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=int(self.num_workers),
num_workers=self.num_workers,
collate_fn=self.collate_fn,
pin_memory=torch.cuda.is_available()
)
return self._predict(data_loader).argmax(axis=1).cpu().numpy()

@@ -348,8 +349,9 @@ class BasicNN:
data_loader = DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=int(self.num_workers),
num_workers=self.num_workers,
collate_fn=self.collate_fn,
pin_memory=torch.cuda.is_available()
)
return self._predict(data_loader).softmax(axis=1).cpu().numpy()

@@ -381,11 +383,9 @@ class BasicNN:
model.eval()

total_correct_num, total_num, total_loss = 0, 0, 0.0

with torch.no_grad():
for data, target in data_loader:
data, target = data.to(device), target.to(device)

out = model(data)

if len(out.shape) > 1:
@@ -482,8 +482,9 @@ class BasicNN:
dataset,
batch_size=self.batch_size,
shuffle=shuffle,
num_workers=int(self.num_workers),
num_workers=self.num_workers,
collate_fn=self.collate_fn,
pin_memory=torch.cuda.is_available()
)
return data_loader



+ 0
- 2
abl/reasoning/kb.py View File

@@ -77,8 +77,6 @@ class KBBase(ABC):
logger="current",
level=logging.WARNING,
)
# TODO 添加半监督
# TODO 添加consistency measure+max_err容忍错误

@abstractmethod
def logic_forward(self, pseudo_label: List[Any], x: Optional[List[Any]] = None) -> Any:


+ 1
- 0
pyproject.toml View File

@@ -17,6 +17,7 @@ classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development",


Loading…
Cancel
Save