From 7ceb3a3da14b09ca8bfae11467c42c59b2000037 Mon Sep 17 00:00:00 2001 From: Tony-HYX <605698554@qq.com> Date: Sat, 6 Jan 2024 21:19:46 +0800 Subject: [PATCH] [MNI] Minor issues in code --- abl/learning/basic_nn.py | 11 ++++++----- abl/reasoning/kb.py | 2 -- pyproject.toml | 1 + 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/abl/learning/basic_nn.py b/abl/learning/basic_nn.py index b2309e0..18f38ec 100644 --- a/abl/learning/basic_nn.py +++ b/abl/learning/basic_nn.py @@ -307,8 +307,9 @@ class BasicNN: data_loader = DataLoader( dataset, batch_size=self.batch_size, - num_workers=int(self.num_workers), + num_workers=self.num_workers, collate_fn=self.collate_fn, + pin_memory=torch.cuda.is_available() ) return self._predict(data_loader).argmax(axis=1).cpu().numpy() @@ -348,8 +349,9 @@ class BasicNN: data_loader = DataLoader( dataset, batch_size=self.batch_size, - num_workers=int(self.num_workers), + num_workers=self.num_workers, collate_fn=self.collate_fn, + pin_memory=torch.cuda.is_available() ) return self._predict(data_loader).softmax(axis=1).cpu().numpy() @@ -381,11 +383,9 @@ class BasicNN: model.eval() total_correct_num, total_num, total_loss = 0, 0, 0.0 - with torch.no_grad(): for data, target in data_loader: data, target = data.to(device), target.to(device) - out = model(data) if len(out.shape) > 1: @@ -482,8 +482,9 @@ class BasicNN: dataset, batch_size=self.batch_size, shuffle=shuffle, - num_workers=int(self.num_workers), + num_workers=self.num_workers, collate_fn=self.collate_fn, + pin_memory=torch.cuda.is_available() ) return data_loader diff --git a/abl/reasoning/kb.py b/abl/reasoning/kb.py index 0fd9f48..90e0dc0 100644 --- a/abl/reasoning/kb.py +++ b/abl/reasoning/kb.py @@ -77,8 +77,6 @@ class KBBase(ABC): logger="current", level=logging.WARNING, ) - # TODO 添加半监督 - # TODO 添加consistency measure+max_err容忍错误 @abstractmethod def logic_forward(self, pseudo_label: List[Any], x: Optional[List[Any]] = None) -> Any: diff --git a/pyproject.toml b/pyproject.toml index a83f0dd..d684c55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ classifiers = [ "Development Status :: 4 - Beta", "Intended Audience :: Science/Research", "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", "Programming Language :: Python", "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development",