From 88fa46669152bfeeae4a6e98eab728be9265cd7c Mon Sep 17 00:00:00 2001 From: Gao Enhao Date: Mon, 27 Nov 2023 16:33:49 +0800 Subject: [PATCH] [MNT] modify doc string in learning --- abl/learning/abl_model.py | 28 ++++------------------------ abl/learning/basic_nn.py | 8 ++++++-- 2 files changed, 10 insertions(+), 26 deletions(-) diff --git a/abl/learning/abl_model.py b/abl/learning/abl_model.py index bcf03df..b433c64 100644 --- a/abl/learning/abl_model.py +++ b/abl/learning/abl_model.py @@ -24,24 +24,6 @@ class ABLModel: ---------- base_model : Machine Learning Model The base model to use for training and prediction. - - Attributes - ---------- - classifier_list : List[Any] - A list of classifiers. - - Methods - ------- - predict(X: List[List[Any]], mapping: Optional[Dict] = None) -> Dict - Predict the labels and probabilities for the given data. - valid(X: List[List[Any]], Y: List[Any]) -> float - Calculate the accuracy score for the given data. - train(X: List[List[Any]], Y: List[Any]) -> float - Train the model on the given data. - save(*args, **kwargs) -> None - Save the model to a file. - load(*args, **kwargs) -> None - Load the model from a file. """ def __init__(self, base_model: Any) -> None: @@ -56,8 +38,8 @@ class ABLModel: Parameters ---------- - X : List[List[Any]] - The data to predict on. + data_samples : ListData + A batch of data to predict on. Returns ------- @@ -86,10 +68,8 @@ class ABLModel: Parameters ---------- - X : List[List[Any]] - The data to train on. - Y : List[Any] - The true labels for the given data. + data_samples : ListData + A batch of data to train on, which typically contains the data, `X`, and the corresponding labels, `abduced_idx`. Returns ------- diff --git a/abl/learning/basic_nn.py b/abl/learning/basic_nn.py index 115b098..9a068b5 100644 --- a/abl/learning/basic_nn.py +++ b/abl/learning/basic_nn.py @@ -48,8 +48,10 @@ class BasicNN: The interval at which to save the model during training, by default None. save_dir : Optional[str], optional The directory in which to save the model during training, by default None. - transform : Callable[..., Any], optional - A function/transform that takes in an object and returns a transformed version, by default None. + train_transform : Callable[..., Any], optional + A function/transform that takes in an object and returns a transformed version used in the `fit` and `train_epoch` methods, by default None. + test_transform : Callable[..., Any], optional + A function/transform that takes in an object and returns a transformed version in the `predict`, `predict_proba` and `score` methods, , by default None. collate_fn : Callable[[List[T]], Any], optional The function used to collate data, by default None. """ @@ -344,6 +346,8 @@ class BasicNN: Input samples. y : List[int], optional Target labels. If None, dummy labels are created, by default None. + shuffle : bool, optional + Whether to shuffle the data, by default True. Returns -------