diff --git a/learnware/learnware/reuse.py b/learnware/learnware/reuse.py index d3b0c54..1526190 100644 --- a/learnware/learnware/reuse.py +++ b/learnware/learnware/reuse.py @@ -1,6 +1,6 @@ import torch import numpy as np -import tensorflow as tf +# import tensorflow as tf from typing import Tuple, Any, List, Union, Dict from cvxopt import matrix, solvers from lightgbm import LGBMClassifier @@ -56,8 +56,11 @@ class JobSelectorReuser(BaseReuser): pred_y = self.learnware_list[idx].predict(user_data[data_idx_list]) if isinstance(pred_y, torch.Tensor): pred_y = pred_y.detach().cpu().numpy() - elif isinstance(pred_y, tf.Tensor): - pred_y = pred_y.numpy() + # elif isinstance(pred_y, tf.Tensor): + # pred_y = pred_y.numpy() + + if not isinstance(pred_y, np.ndarray): + raise TypeError(f"Model output must be np.ndarray or torch.Tensor") pred_y_list.append(pred_y) data_idxs_list.append(data_idx_list) @@ -292,9 +295,12 @@ class AveragingReuser(BaseReuser): pred_y = self.learnware_list[idx].predict(user_data) if isinstance(pred_y, torch.Tensor): pred_y = pred_y.detach().cpu().numpy() - elif isinstance(pred_y, tf.Tensor): - pred_y = pred_y.numpy() + # elif isinstance(pred_y, tf.Tensor): + # pred_y = pred_y.numpy() + if not isinstance(pred_y, np.ndarray): + raise TypeError(f"Model output must be np.ndarray or torch.Tensor") + if self.mode == "mean": if mean_pred_y is None: mean_pred_y = pred_y