You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

example_init.py 843 B

123456789101112131415161718192021222324252627282930
  1. import os
  2. import pickle
  3. import numpy as np
  4. from learnware.model import BaseModel
  class Model(BaseModel):
      """Learnware model wrapping two pickled objects stored next to this file.

      ``modelv.pth`` is used through its ``transform`` method and
      ``modell.pth`` through its ``predict_proba`` method. Presumably these are
      a fitted feature transformer and a fitted classifier (TODO confirm).
      Despite the ``.pth`` extension, both files are read with :mod:`pickle`,
      not PyTorch.
      """

      def __init__(self):
          # The shapes are passed to BaseModel unchanged. Presumably this means
          # one raw input per sample and 20 class probabilities per sample.
          # TODO confirm against the learnware's specification.
          super().__init__(input_shape=(1,), output_shape=(20,))
          dir_path = os.path.dirname(os.path.abspath(__file__))
          # NOTE(review): pickle.load can execute arbitrary code embedded in
          # the file. Only load learnware packages from trusted sources.
          self.modelv = self._load_pickle(os.path.join(dir_path, "modelv.pth"))
          self.modell = self._load_pickle(os.path.join(dir_path, "modell.pth"))

      @staticmethod
      def _load_pickle(path: str):
          """Open *path* in binary mode and return the unpickled object."""
          with open(path, "rb") as f:
              return pickle.load(f)

          def fit(self, X: np.ndarray, y: np.ndarray):
              """Intentionally do nothing; the loaded models are used as-is."""

      def predict(self, X: np.ndarray) -> np.ndarray:
          """Return class-probability predictions for *X*.

          *X* is first passed through ``modelv.transform``. The result is then
          fed to ``modell.predict_proba``. Probabilities are returned instead of
          hard labels on purpose, so that the output matches
          ``output_shape=(20,)``.
          """
          return self.modell.predict_proba(self.modelv.transform(X))

      def finetune(self, X: np.ndarray, y: np.ndarray):
          """Intentionally do nothing; fine-tuning is not supported."""