| @@ -233,10 +233,10 @@ class SimpleBridge(BaseBridge): | |||||
| Data will be split into segments of this size and data in each segment | Data will be split into segments of this size and data in each segment | ||||
| will be used together to train the model, by default 1.0. | will be used together to train the model, by default 1.0. | ||||
| eval_interval : int | eval_interval : int | ||||
| The model will be evaluated every ``eval_interval`` loops during training, | |||||
| The model will be evaluated every ``eval_interval`` loop during training, | |||||
| by default 1. | by default 1. | ||||
| save_interval : int, optional | save_interval : int, optional | ||||
| The model will be saved every ``eval_interval`` loops during training, by | |||||
| The model will be saved every ``eval_interval`` loop during training, by | |||||
| default None. | default None. | ||||
| save_dir : str, optional | save_dir : str, optional | ||||
| Directory to save the model, by default None. | Directory to save the model, by default None. | ||||
| @@ -52,7 +52,7 @@ class BaseMetric(metaclass=ABCMeta): | |||||
| ------- | ------- | ||||
| dict | dict | ||||
| The computed metrics. The keys are the names of the metrics, | The computed metrics. The keys are the names of the metrics, | ||||
| and the values are corresponding results. | |||||
| and the values are the corresponding results. | |||||
| """ | """ | ||||
| def evaluate(self) -> dict: | def evaluate(self) -> dict: | ||||
| @@ -64,7 +64,7 @@ class BaseMetric(metaclass=ABCMeta): | |||||
| ------- | ------- | ||||
| dict | dict | ||||
| Evaluation metrics dict on the val dataset. The keys are the | Evaluation metrics dict on the val dataset. The keys are the | ||||
| names of the metrics, and the values are corresponding results. | |||||
| names of the metrics, and the values are the corresponding results. | |||||
| """ | """ | ||||
| if len(self.results) == 0: | if len(self.results) == 0: | ||||
| print_log( | print_log( | ||||
| @@ -11,7 +11,7 @@ class SymbolAccuracy(BaseMetric): | |||||
| A metrics class for evaluating symbol-level accuracy. | A metrics class for evaluating symbol-level accuracy. | ||||
| This class is designed to assess the accuracy of symbol prediction. Symbol accuracy | This class is designed to assess the accuracy of symbol prediction. Symbol accuracy | ||||
| are calculated by comparing predicted presudo labels and their ground truth. | |||||
| is calculated by comparing predicted presudo labels and their ground truth. | |||||
| Parameters | Parameters | ||||
| ---------- | ---------- | ||||
| @@ -86,7 +86,7 @@ class ABLModel: | |||||
| Returns | Returns | ||||
| ------- | ------- | ||||
| float | float | ||||
| The accuracy the trained model. | |||||
| The accuracy of the trained model. | |||||
| """ | """ | ||||
| data_X = data_examples.flatten("X") | data_X = data_examples.flatten("X") | ||||
| data_y = data_examples.flatten("abduced_idx") | data_y = data_examples.flatten("abduced_idx") | ||||
| @@ -40,7 +40,7 @@ class BasicNN: | |||||
| num_workers : int | num_workers : int | ||||
| The number of workers used for loading data, by default 0. | The number of workers used for loading data, by default 0. | ||||
| save_interval : int, optional | save_interval : int, optional | ||||
| The model will be saved every ``save_interval`` epochs during training, by default None. | |||||
| The model will be saved every ``save_interval`` epoch during training, by default None. | |||||
| save_dir : str, optional | save_dir : str, optional | ||||
| The directory in which to save the model during training, by default None. | The directory in which to save the model during training, by default None. | ||||
| train_transform : Callable[..., Any], optional | train_transform : Callable[..., Any], optional | ||||
| @@ -63,7 +63,7 @@ class ModelConverter: | |||||
| num_workers : int | num_workers : int | ||||
| The number of workers used for loading data, by default 0. | The number of workers used for loading data, by default 0. | ||||
| save_interval : int, optional | save_interval : int, optional | ||||
| The model will be saved every ``save_interval`` epochs during training, by default None. | |||||
| The model will be saved every ``save_interval`` epoch during training, by default None. | |||||
| save_dir : str, optional | save_dir : str, optional | ||||
| The directory in which to save the model during training, by default None. | The directory in which to save the model during training, by default None. | ||||
| train_transform : Callable[..., Any], optional | train_transform : Callable[..., Any], optional | ||||
| @@ -153,7 +153,7 @@ class ModelConverter: | |||||
| num_workers : int | num_workers : int | ||||
| The number of workers used for loading data, by default 0. | The number of workers used for loading data, by default 0. | ||||
| save_interval : int, optional | save_interval : int, optional | ||||
| The model will be saved every ``save_interval`` epochs during training, by default None. | |||||
| The model will be saved every ``save_interval`` epoch during training, by default None. | |||||
| save_dir : str, optional | save_dir : str, optional | ||||
| The directory in which to save the model during training, by default None. | The directory in which to save the model during training, by default None. | ||||
| train_transform : Callable[..., Any], optional | train_transform : Callable[..., Any], optional | ||||
| @@ -94,7 +94,7 @@ def confidence_dist(pred_prob: List[np.ndarray], candidates_idxs: List[List[Any] | |||||
| Parameters | Parameters | ||||
| ---------- | ---------- | ||||
| pred_prob : List[np.ndarray] | pred_prob : List[np.ndarray] | ||||
| Prediction probability distributions, each element is an ndarray | |||||
| Prediction probability distributions, each element is an array | |||||
| representing the probability distribution of a particular prediction. | representing the probability distribution of a particular prediction. | ||||
| candidates_idxs : List[List[Any]] | candidates_idxs : List[List[Any]] | ||||
| Multiple possible candidates' indices. | Multiple possible candidates' indices. | ||||