
[DOC] add readme.md for models part

pull/3/head
Gao Enhao · 2 years ago · commit 6a5cfbf5ec
1 changed file with 136 additions and 0 deletions

abl/models/readme.md  +136  -0

@@ -0,0 +1,136 @@
# `basic_model.py`

The `BasicModel` class implemented in `basic_model.py` can be used to wrap a `pytorch` neural network into an `sklearn`-style model.

## Interface provided by the BasicModel class

| Method | Description |
| ---- | ---- |
| fit(X, y) | Train the neural network |
| predict(X) | Predict the classes of X |
| predict_proba(X) | Predict the class probabilities of X |
| score(X, y) | Compute the model's accuracy on the test data |
| save() | Save the model |
| load() | Load the model |


## Parameters of the BasicModel class

**model : torch.nn.Module**
+ The PyTorch model to be trained or used for prediction.

**batch_size : int**
+ The batch size used for training.

**num_epochs : int**
+ The number of epochs used for training.

**stop_loss : Optional[float]**
+ The loss value at which to stop training.

**num_workers : int**
+ The number of workers used for loading data.

**criterion : torch.nn.Module**
+ The loss function used for training.

**optimizer : torch.optim.Optimizer**
+ The optimizer used for training.

**transform : Callable[..., Any]**
+ The transformation function used for data augmentation.

**device : torch.device**
+ The device on which the model will be trained or used for prediction.

**recorder : Any**
+ The recorder used to record training progress.

**save_interval : Optional[int]**
+ The interval at which to save the model during training.

**save_dir : Optional[str]**
+ The directory in which to save the model during training.

**collate_fn : Callable[[List[T]], Any]**
+ The function used to collate data.
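
As an illustration of the optional parameters above, here is a hedged sketch of a fuller construction call. The positional order follows the example in the next section; treating `stop_loss`, `num_workers`, `save_interval`, and `save_dir` as keyword arguments, and the values chosen here, are assumptions for illustration only.

> ```python
> import torch
> import torch.nn as nn
>
> cls = LeNet5()                      # any torch.nn.Module
> criterion = nn.CrossEntropyLoss()
> optimizer = torch.optim.Adam(cls.parameters())
>
> base_model = BasicModel(
>     cls,
>     criterion,
>     optimizer,
>     torch.device("cuda:0"),
>     batch_size=32,
>     num_epochs=10,
>     stop_loss=1e-3,                 # assumed: stop once the loss drops below this value
>     num_workers=4,                  # DataLoader worker processes
>     save_interval=5,                # assumed: save a checkpoint every 5 epochs
>     save_dir="./checkpoints",       # assumed: directory for the saved checkpoints
> )
> ```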

## Example
>
> ```python
> import torch
> import torch.nn as nn
>
> # Three necessary components
> cls = LeNet5()
> criterion = nn.CrossEntropyLoss()
> optimizer = torch.optim.Adam(cls.parameters())
>
> # Initialize base_model
> base_model = BasicModel(
>     cls,
>     criterion,
>     optimizer,
>     torch.device("cuda:0"),
>     batch_size=32,
>     num_epochs=10,
> )
>
> # Prepare data
> train_X, train_y = get_train_data()
> test_X, test_y = get_test_data()
>
> # Train model
> base_model.fit(train_X, train_y)
>
> # Predict
> base_model.predict(test_X)
>
> # Validation
> base_model.score(test_X, test_y)
> ```
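
The example above does not exercise `predict_proba`, `save`, or `load`. A minimal continuation, under the assumption that the argument-free `save()`/`load()` listed in the table rely on the `save_dir` given at construction time:

> ```python
> # Class probabilities for each sample in test_X
> probs = base_model.predict_proba(test_X)
>
> # Persist the trained model and restore it later
> # (assumed to use the `save_dir` passed to BasicModel)
> base_model.save()
> base_model.load()
> ```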

# `wabl_models.py`

The `WABLBasicModel` class implemented in `wabl_models.py` serializes the data and provides a unified interface for different machine learning models.

## Interface provided by the WABLBasicModel class

| Method | Description |
| ---- | ---- |
| train(X, Y) | Train the machine learning model on the training data (abduction is not involved) |
| predict(X) | Predict the classes and probabilities of X |
| valid(X, Y) | Compute the model's accuracy on the test data |

## Parameters of the WABLBasicModel class
**base_model : Machine Learning Model**
+ The base model to use for training and prediction.

**pseudo_label_list : List[Any]**
+ A list of pseudo labels to use for training.
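
For instance, in a handwritten-digit task the pseudo label list would simply enumerate the ten digit classes. The snippet below is a hypothetical illustration; in practice the list usually comes from the knowledge base, e.g. `kb.pseudo_label_list` in the example further below.

> ```python
> # Hypothetical pseudo label list for a 10-class digit task
> pseudo_label_list = list(range(10))   # [0, 1, ..., 9]
>
> # The wrapped learner can be, e.g., the BasicModel from the previous section
> model = WABLBasicModel(base_model, pseudo_label_list)
> ```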

## Serializing the data
The training data may be organized in several different forms, for example:\
`X: List[List[img]], Y: List[List[label]]`\
`X: List[List[img]], Y: List[label]`\
`X: List[img], Y: List[label]`\
...\
which is inconvenient for training. Therefore, the data is first unified into the form `X: List[img], Y: List[label]`, the so-called serialized data.
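
A minimal sketch of this flattening for the `X: List[List[img]], Y: List[List[label]]` case (the helper name `serialize` is illustrative and not part of the module):

> ```python
> from typing import Any, List, Tuple
>
> def serialize(X: List[List[Any]], Y: List[List[Any]]) -> Tuple[List[Any], List[Any]]:
>     """Flatten nested samples into X: List[img], Y: List[label]."""
>     flat_X = [img for sample in X for img in sample]
>     flat_Y = [label for labels in Y for label in labels]
>     return flat_X, flat_Y
> ```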

## Example
>
> ```python
> # Initialize the model
> # 'ml_model' is no longer limited to NN models
> model = WABLBasicModel(ml_model, kb.pseudo_label_list)
>
> # Prepare data
> train_X, train_y = get_train_data()
> test_X, test_y = get_test_data()
>
> # Train model
> model.train(train_X, train_y)
>
> # Predict
> model.predict(test_X)
>
> # Validation
> model.valid(test_X, test_y)
> ```
