Browse Source

Merge branch 'dev' of git.nju.edu.cn:learnware/learnware-market into dev

tags/v0.3.2
bxdd 3 years ago
parent
commit
6a67ce8c58
1 changed files with 11 additions and 5 deletions
  1. +11
    -5
      learnware/learnware/reuse.py

+ 11
- 5
learnware/learnware/reuse.py View File

@@ -1,6 +1,6 @@
import torch
import numpy as np
import tensorflow as tf
# import tensorflow as tf
from typing import Tuple, Any, List, Union, Dict
from cvxopt import matrix, solvers
from lightgbm import LGBMClassifier
@@ -56,8 +56,11 @@ class JobSelectorReuser(BaseReuser):
pred_y = self.learnware_list[idx].predict(user_data[data_idx_list])
if isinstance(pred_y, torch.Tensor):
pred_y = pred_y.detach().cpu().numpy()
elif isinstance(pred_y, tf.Tensor):
pred_y = pred_y.numpy()
# elif isinstance(pred_y, tf.Tensor):
# pred_y = pred_y.numpy()
if not isinstance(pred_y, np.ndarray):
raise TypeError(f"Model output must be np.ndarray or torch.Tensor")

pred_y_list.append(pred_y)
data_idxs_list.append(data_idx_list)
@@ -292,9 +295,12 @@ class AveragingReuser(BaseReuser):
pred_y = self.learnware_list[idx].predict(user_data)
if isinstance(pred_y, torch.Tensor):
pred_y = pred_y.detach().cpu().numpy()
elif isinstance(pred_y, tf.Tensor):
pred_y = pred_y.numpy()
# elif isinstance(pred_y, tf.Tensor):
# pred_y = pred_y.numpy()

if not isinstance(pred_y, np.ndarray):
raise TypeError(f"Model output must be np.ndarray or torch.Tensor")
if self.mode == "mean":
if mean_pred_y is None:
mean_pred_y = pred_y


Loading…
Cancel
Save