From 6a5cfbf5ecc4e0ec5b638aede31cce378b189570 Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Thu, 30 Mar 2023 20:52:56 +0800 Subject: [PATCH] [DOC] add readme.md for models part --- abl/models/readme.md | 136 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 abl/models/readme.md diff --git a/abl/models/readme.md b/abl/models/readme.md new file mode 100644 index 0000000..76e7bbc --- /dev/null +++ b/abl/models/readme.md @@ -0,0 +1,136 @@ +# `basic_model.py` + +可以使用`basic_model.py`中实现的`BasicModel`类将`pytorch`神经网络模型包装成`sklearn`模型的形式. + +## BasicModel 类提供的接口 + +| 方法 | 功能 | +| ---- | ---- | +| fit(X, y) | 训练神经网络 | +| predict(X) | 预测 X 的类别 | +| predict_proba(X) | 预测 X 的类别概率 | +| score(X, y) | 计算模型在测试数据上的准确率 | +| save() | 保存模型 | +| load() | 加载模型 | + + +## BasicModel 类的参数 + +**model : torch.nn.Module** ++ The PyTorch model to be trained or used for prediction. + +**batch_size : int** ++ The batch size used for training. + +**num_epochs : int** ++ The number of epochs used for training. + +**stop_loss : Optional[float]** ++ The loss value at which to stop training. + +**num_workers : int** ++ The number of workers used for loading data. + +**criterion : torch.nn.Module** ++ The loss function used for training. + +**optimizer : torch.nn.Module** ++ The optimizer used for training. + +**transform : Callable[..., Any]** ++ The transformation function used for data augmentation. + +**device : torch.device** ++ The device on which the model will be trained or used for prediction. + +**recorder : Any** ++ The recorder used to record training progress. + +**save_interval : Optional[int]** ++ The interval at which to save the model during training. + +**save_dir : Optional[str]** ++ The directory in which to save the model during training. + +**collate_fn : Callable[[List[T]], Any]** ++ The function used to collate data. + +## 例子 +> +> ```python +> # Three necessary component +> cls = LeNet5() +> criterion = nn.CrossEntropyLoss() +> optimizer = torch.optim.Adam(cls.parameters()) +> +> # Initialize base_model +> base_model = BasicModel( +> cls, +> criterion, +> optimizer, +> torch.device("cuda:0"), +> batch_size=32, +> num_epochs=10, +> ) +> +> # Prepare data +> train_X, train_y = get_train_data() +> test_X, test_y = get_test_data() +> +> # Train model +> base_model.fit(train_X, train_y) +> +> # Predict +> base_model.predict(test_X) +> +> # Validation +> base_model.score(test_X, test_y) +> ``` + +# `wabl_models.py` + +`wabl_models.py`中实现的`WABLBasicModel`能够序列化数据并为不同的机器学习模型提供统一的接口. + +## WABLBasicModel 类提供的接口 + +| 方法 | 功能 | +| ---- | ---- | +| train(X, Y) | 利用训练数据训练机器学习模型(不涉及反绎) | +| predict(X) | 预测 X 的类别和概率 | +| valid(X, Y) | 计算模型在测试数据上的准确率 | + +## WABLBasicModel 类的参数 +**base_model : Machine Learning Model** ++ The base model to use for training and prediction. + +**pseudo_label_list : List[Any]** ++ A list of pseudo labels to use for training. + +## 序列化数据 +考虑到训练数据可能多种组织形式,比如:\ +`X: List[List[img]], Y: List[List[label]]`\ +`X: List[List[img]], Y: List[label]`\ +`X: List[img], Y: List[label]` +... \ +不便于训练. 因此先将形式统一为:`X: List[img], Y: List[label]`,也就是所谓的序列化数据. + +## 例子 +> +> ```python +> # Three necessary component +> # 'ml_model' is no longer limited to NN models +> model = WABLBasicModel(ml_model, kb.pseudo_label_list) +> +> # Prepare data +> train_X, train_y = get_train_data() +> test_X, test_y = get_test_data() +> +> # Train model +> model.train(train_X, train_y) +> +> # Predict +> model.predict(test_X) +> +> # Validation +> model.valid(test_X, test_y) +> ``` \ No newline at end of file