You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

example_init.py 1.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. import os
  2. import joblib
  3. import numpy as np
  4. from learnware.model import BaseModel
  5. import torch
  6. from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER
  7. import torchtext.functional as F
  8. import torchtext.transforms as T
  9. from torch.hub import load_state_dict_from_url
  class Model(BaseModel):
      """Two-class text classifier: XLM-R base encoder + RoBERTa classification head.

      The fine-tuned weights are read from ``model.pth`` in the same directory as
      this file. The model is inference-only; ``fit`` and ``finetune`` are no-ops.
      """

      _NUM_CLASSES = 2
      _INPUT_DIM = 768  # width fed into the classification head (XLM-R base hidden size)
      _PADDING_IDX = 1  # pad token id used when batching token-id sequences

      def __init__(self):
          super().__init__(input_shape=None, output_shape=(self._NUM_CLASSES,))
          dir_path = os.path.dirname(os.path.abspath(__file__))
          self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
          classifier_head = RobertaClassificationHead(
              num_classes=self._NUM_CLASSES, input_dim=self._INPUT_DIM
          )
          # Every parameter is overwritten by the strict load_state_dict below, so
          # downloading the pretrained encoder weights first would be wasted work.
          self.model = XLMR_BASE_ENCODER.get_model(
              head=classifier_head, load_weights=False
          ).to(self.device)
          # map_location lets a checkpoint saved on GPU load on a CPU-only host.
          state_dict = torch.load(
              os.path.join(dir_path, "model.pth"), map_location=self.device
          )
          self.model.load_state_dict(state_dict)
          # Inference only: switch off dropout so predictions are deterministic.
          self.model.eval()

      def fit(self, X: np.ndarray, y: np.ndarray):
          """No-op: this model ships pre-trained."""
          pass

      def predict(self, X: np.ndarray) -> np.ndarray:
          """Score a batch of raw sentences.

          Args:
              X: Sequence of input sentences (strings).

          Returns:
              NumPy array on the CPU with one row of ``_NUM_CLASSES`` head outputs
              per sentence (raw scores; no softmax is applied here). Empty input
              yields an array of shape ``(0, _NUM_CLASSES)``.
          """
          if len(X) == 0:
              # F.to_tensor cannot pad an empty batch, so short-circuit it.
              return np.empty((0, self._NUM_CLASSES), dtype=np.float32)
          token_ids = sentence_preprocess(X)
          batch = F.to_tensor(token_ids, padding_value=self._PADDING_IDX).to(self.device)
          # No gradients are needed for inference; avoid building an autograd graph.
          with torch.no_grad():
              scores = self.model(batch)
          return scores.cpu().numpy()

      def finetune(self, X: np.ndarray, y: np.ndarray):
          """No-op: fine-tuning is not supported by this model."""
          pass
  # Module-wide cache for the XLM-R tokenization pipeline; built on first use.
  _TEXT_TRANSFORM = None


  def _get_text_transform():
      """Build the XLM-R tokenization pipeline once and return the cached instance."""
      global _TEXT_TRANSFORM
      if _TEXT_TRANSFORM is None:
          bos_idx = 0
          eos_idx = 2
          max_seq_len = 256
          xlmr_vocab_path = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
          xlmr_spm_model_path = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"
          # Building this fetches the vocab and SentencePiece model from the URLs
          # above, so it is done once instead of on every predict call.
          _TEXT_TRANSFORM = T.Sequential(
              T.SentencePieceTokenizer(xlmr_spm_model_path),
              T.VocabTransform(load_state_dict_from_url(xlmr_vocab_path)),
              # Leave room for the BOS and EOS tokens appended below.
              T.Truncate(max_seq_len - 2),
              T.AddToken(token=bos_idx, begin=True),
              T.AddToken(token=eos_idx, begin=False),
          )
      return _TEXT_TRANSFORM


  def sentence_preprocess(x_datapipe):
      """Tokenize sentences into XLM-R token-id sequences.

      Args:
          x_datapipe: Iterable of input sentences (strings).

      Returns:
          List with one token-id sequence per input. Each sequence starts with
          id 0 (BOS), ends with id 2 (EOS) and is at most 256 ids long.
      """
      text_transform = _get_text_transform()
      return [text_transform(x) for x in x_datapipe]