Are you sure you want to delete this task? Once this task is deleted, it cannot be recovered.
|
|
2 years ago | |
|---|---|---|
| .. | ||
| __init__.py | 2 years ago | |
| abl_model.py | 2 years ago | |
| basic_nn.py | 2 years ago | |
| readme.md | 2 years ago | |
basic_model.py可以使用basic_model.py中实现的BasicModel类将pytorch神经网络模型包装成sklearn模型的形式.
| 方法 | 功能 |
|---|---|
| fit(X, y) | 训练神经网络 |
| predict(X) | 预测 X 的类别 |
| predict_proba(X) | 预测 X 的类别概率 |
| score(X, y) | 计算模型在测试数据上的准确率 |
| save() | 保存模型 |
| load() | 加载模型 |
model : torch.nn.Module
batch_size : int
num_epochs : int
stop_loss : Optional[float]
num_workers : int
loss_fn : torch.nn.Module
optimizer : torch.nn.Module
transform : Callable[..., Any]
device : torch.device
recorder : Any
save_interval : Optional[int]
save_dir : Optional[str]
collate_fn : CallableList[T, Any]
# Three necessary component cls = LeNet5() loss_fn = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(cls.parameters()) # Initialize base_model base_model = BasicModel( cls, loss_fn, optimizer, torch.device("cuda:0"), batch_size=32, num_epochs=10, ) # Prepare data train_X, train_y = get_train_data() test_X, test_y = get_test_data() # Train model base_model.fit(train_X, train_y) # Predict base_model.predict(test_X) # Validation base_model.score(test_X, test_y)
wabl_models.pywabl_models.py中实现的WABLBasicModel能够序列化数据并为不同的机器学习模型提供统一的接口.
| 方法 | 功能 |
|---|---|
| train(X, Y) | 利用训练数据训练机器学习模型(不涉及反绎) |
| predict(X) | 预测 X 的类别和概率 |
| valid(X, Y) | 计算模型在测试数据上的准确率 |
base_model : Machine Learning Model
pseudo_label_list : List[Any]
考虑到训练数据可能多种组织形式,比如:
X: List[List[img]], Y: List[List[label]]
X: List[List[img]], Y: List[label]
X: List[img], Y: List[label]
...
不便于训练. 因此先将形式统一为:X: List[img], Y: List[label],也就是所谓的序列化数据.
# Three necessary component # 'ml_model' is no longer limited to NN models model = WABLBasicModel(ml_model, kb.pseudo_label_list) # Prepare data train_X, train_y = get_train_data() test_X, test_y = get_test_data() # Train model model.train(train_X, train_y) # Predict model.predict(test_X) # Validation model.valid(test_X, test_y)
An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.
Python other