Browse Source

[MNT] modify doc string in learning

pull/1/head
Gao Enhao 2 years ago
parent
commit
88fa466691
2 changed files with 10 additions and 26 deletions
  1. +4
    -24
      abl/learning/abl_model.py
  2. +6
    -2
      abl/learning/basic_nn.py

+ 4
- 24
abl/learning/abl_model.py View File

@@ -24,24 +24,6 @@ class ABLModel:
----------
base_model : Machine Learning Model
The base model to use for training and prediction.

Attributes
----------
classifier_list : List[Any]
A list of classifiers.

Methods
-------
predict(X: List[List[Any]], mapping: Optional[Dict] = None) -> Dict
Predict the labels and probabilities for the given data.
valid(X: List[List[Any]], Y: List[Any]) -> float
Calculate the accuracy score for the given data.
train(X: List[List[Any]], Y: List[Any]) -> float
Train the model on the given data.
save(*args, **kwargs) -> None
Save the model to a file.
load(*args, **kwargs) -> None
Load the model from a file.
"""

def __init__(self, base_model: Any) -> None:
@@ -56,8 +38,8 @@ class ABLModel:

Parameters
----------
X : List[List[Any]]
The data to predict on.
data_samples : ListData
A batch of data to predict on.

Returns
-------
@@ -86,10 +68,8 @@ class ABLModel:

Parameters
----------
X : List[List[Any]]
The data to train on.
Y : List[Any]
The true labels for the given data.
data_samples : ListData
A batch of data to train on, which typically contains the data, `X`, and the corresponding labels, `abduced_idx`.

Returns
-------


+ 6
- 2
abl/learning/basic_nn.py View File

@@ -48,8 +48,10 @@ class BasicNN:
The interval at which to save the model during training, by default None.
save_dir : Optional[str], optional
The directory in which to save the model during training, by default None.
transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version, by default None.
train_transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version used in the `fit` and `train_epoch` methods, by default None.
test_transform : Callable[..., Any], optional
A function/transform that takes in an object and returns a transformed version in the `predict`, `predict_proba` and `score` methods, , by default None.
collate_fn : Callable[[List[T]], Any], optional
The function used to collate data, by default None.
"""
@@ -344,6 +346,8 @@ class BasicNN:
Input samples.
y : List[int], optional
Target labels. If None, dummy labels are created, by default None.
shuffle : bool, optional
Whether to shuffle the data, by default True.

Returns
-------


Loading…
Cancel
Save