| @@ -0,0 +1,136 @@ | |||
| # `basic_model.py` | |||
| 可以使用`basic_model.py`中实现的`BasicModel`类将`pytorch`神经网络模型包装成`sklearn`模型的形式. | |||
| ## BasicModel 类提供的接口 | |||
| | 方法 | 功能 | | |||
| | ---- | ---- | | |||
| | fit(X, y) | 训练神经网络 | | |||
| | predict(X) | 预测 X 的类别 | | |||
| | predict_proba(X) | 预测 X 的类别概率 | | |||
| | score(X, y) | 计算模型在测试数据上的准确率 | | |||
| | save() | 保存模型 | | |||
| | load() | 加载模型 | | |||
| ## BasicModel 类的参数 | |||
| **model : torch.nn.Module** | |||
| + The PyTorch model to be trained or used for prediction. | |||
| **batch_size : int** | |||
| + The batch size used for training. | |||
| **num_epochs : int** | |||
| + The number of epochs used for training. | |||
| **stop_loss : Optional[float]** | |||
| + The loss value at which to stop training. | |||
| **num_workers : int** | |||
| + The number of workers used for loading data. | |||
| **criterion : torch.nn.Module** | |||
| + The loss function used for training. | |||
| **optimizer : torch.nn.Module** | |||
| + The optimizer used for training. | |||
| **transform : Callable[..., Any]** | |||
| + The transformation function used for data augmentation. | |||
| **device : torch.device** | |||
| + The device on which the model will be trained or used for prediction. | |||
| **recorder : Any** | |||
| + The recorder used to record training progress. | |||
| **save_interval : Optional[int]** | |||
| + The interval at which to save the model during training. | |||
| **save_dir : Optional[str]** | |||
| + The directory in which to save the model during training. | |||
| **collate_fn : Callable[[List[T]], Any]** | |||
| + The function used to collate data. | |||
| ## 例子 | |||
| > | |||
| > ```python | |||
| > # Three necessary component | |||
| > cls = LeNet5() | |||
| > criterion = nn.CrossEntropyLoss() | |||
| > optimizer = torch.optim.Adam(cls.parameters()) | |||
| > | |||
| > # Initialize base_model | |||
| > base_model = BasicModel( | |||
| > cls, | |||
| > criterion, | |||
| > optimizer, | |||
| > torch.device("cuda:0"), | |||
| > batch_size=32, | |||
| > num_epochs=10, | |||
| > ) | |||
| > | |||
| > # Prepare data | |||
| > train_X, train_y = get_train_data() | |||
| > test_X, test_y = get_test_data() | |||
| > | |||
| > # Train model | |||
| > base_model.fit(train_X, train_y) | |||
| > | |||
| > # Predict | |||
| > base_model.predict(test_X) | |||
| > | |||
| > # Validation | |||
| > base_model.score(test_X, test_y) | |||
| > ``` | |||
| # `wabl_models.py` | |||
| `wabl_models.py`中实现的`WABLBasicModel`能够序列化数据并为不同的机器学习模型提供统一的接口. | |||
| ## WABLBasicModel 类提供的接口 | |||
| | 方法 | 功能 | | |||
| | ---- | ---- | | |||
| | train(X, Y) | 利用训练数据训练机器学习模型(不涉及反绎) | | |||
| | predict(X) | 预测 X 的类别和概率 | | |||
| | valid(X, Y) | 计算模型在测试数据上的准确率 | | |||
| ## WABLBasicModel 类的参数 | |||
| **base_model : Machine Learning Model** | |||
| + The base model to use for training and prediction. | |||
| **pseudo_label_list : List[Any]** | |||
| + A list of pseudo labels to use for training. | |||
| ## 序列化数据 | |||
| 考虑到训练数据可能多种组织形式,比如:\ | |||
| `X: List[List[img]], Y: List[List[label]]`\ | |||
| `X: List[List[img]], Y: List[label]`\ | |||
| `X: List[img], Y: List[label]` | |||
| ... \ | |||
| 不便于训练. 因此先将形式统一为:`X: List[img], Y: List[label]`,也就是所谓的序列化数据. | |||
| ## 例子 | |||
| > | |||
| > ```python | |||
| > # Three necessary component | |||
| > # 'ml_model' is no longer limited to NN models | |||
| > model = WABLBasicModel(ml_model, kb.pseudo_label_list) | |||
| > | |||
| > # Prepare data | |||
| > train_X, train_y = get_train_data() | |||
| > test_X, test_y = get_test_data() | |||
| > | |||
| > # Train model | |||
| > model.train(train_X, train_y) | |||
| > | |||
| > # Predict | |||
| > model.predict(test_X) | |||
| > | |||
| > # Validation | |||
| > model.valid(test_X, test_y) | |||
| > ``` | |||