You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

utils.py 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. import os
  2. import numpy as np
  3. import random
  4. import math
  5. import torch
  6. import torch.nn as nn
  7. import torch.optim as optim
  8. import pickle
  9. import torchtext.transforms as T
  10. from torch.hub import load_state_dict_from_url
  11. from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER
  12. import torchtext.functional as F
  13. from torch.optim import AdamW
  14. from torch.utils.data import DataLoader
  15. class TextDataLoader:
  16. def __init__(self, data_root, train: bool = True):
  17. self.data_root = data_root
  18. self.train = train
  19. def get_idx_data(self, idx=0):
  20. if self.train:
  21. X_path = os.path.join(self.data_root, "uploader", "uploader_%d_X.pkl" % (idx))
  22. y_path = os.path.join(self.data_root, "uploader", "uploader_%d_y.pkl" % (idx))
  23. if not (os.path.exists(X_path) and os.path.exists(y_path)):
  24. raise Exception("Index Error")
  25. with open(X_path, "rb") as f:
  26. X = pickle.load(f)
  27. with open(y_path, "rb") as f:
  28. y = pickle.load(f)
  29. else:
  30. X_path = os.path.join(self.data_root, "user", "user_%d_X.pkl" % (idx))
  31. y_path = os.path.join(self.data_root, "user", "user_%d_y.pkl" % (idx))
  32. if not (os.path.exists(X_path) and os.path.exists(y_path)):
  33. raise Exception("Index Error")
  34. with open(X_path, "rb") as f:
  35. X = pickle.load(f)
  36. with open(y_path, "rb") as f:
  37. y = pickle.load(f)
  38. return X, y
  39. def generate_uploader(data_x, data_y, n_uploaders=50, data_save_root=None):
  40. if data_save_root is None:
  41. return
  42. os.makedirs(data_save_root, exist_ok=True)
  43. n = len(data_x)
  44. for i in range(n_uploaders):
  45. selected_X = data_x[i * (n // n_uploaders) : (i + 1) * (n // n_uploaders)]
  46. selected_y = data_y[i * (n // n_uploaders) : (i + 1) * (n // n_uploaders)]
  47. X_save_dir = os.path.join(data_save_root, "uploader_%d_X.pkl" % (i))
  48. y_save_dir = os.path.join(data_save_root, "uploader_%d_y.pkl" % (i))
  49. with open(X_save_dir, "wb") as f:
  50. pickle.dump(selected_X, f)
  51. with open(y_save_dir, "wb") as f:
  52. pickle.dump(selected_y, f)
  53. print("Saving to %s" % (X_save_dir))
  54. def generate_user(data_x, data_y, n_users=50, data_save_root=None):
  55. if data_save_root is None:
  56. return
  57. os.makedirs(data_save_root, exist_ok=True)
  58. n = len(data_x)
  59. for i in range(n_users):
  60. selected_X = data_x[i * (n // n_users) : (i + 1) * (n // n_users)]
  61. selected_y = data_y[i * (n // n_users) : (i + 1) * (n // n_users)]
  62. X_save_dir = os.path.join(data_save_root, "user_%d_X.pkl" % (i))
  63. y_save_dir = os.path.join(data_save_root, "user_%d_y.pkl" % (i))
  64. with open(X_save_dir, "wb") as f:
  65. pickle.dump(selected_X, f)
  66. with open(y_save_dir, "wb") as f:
  67. pickle.dump(selected_y, f)
  68. print("Saving to %s" % (X_save_dir))
  69. def sentence_preprocess(x_datapipe):
  70. padding_idx = 1
  71. bos_idx = 0
  72. eos_idx = 2
  73. max_seq_len = 256
  74. xlmr_vocab_path = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
  75. xlmr_spm_model_path = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"
  76. text_transform = T.Sequential(
  77. T.SentencePieceTokenizer(xlmr_spm_model_path),
  78. T.VocabTransform(load_state_dict_from_url(xlmr_vocab_path)),
  79. T.Truncate(max_seq_len - 2),
  80. T.AddToken(token=bos_idx, begin=True),
  81. T.AddToken(token=eos_idx, begin=False),
  82. )
  83. x_datapipe = [text_transform(x) for x in x_datapipe]
  84. # x_datapipe = x_datapipe.map(text_transform)
  85. return x_datapipe
  86. def train_step(model, criteria, optim, input, target):
  87. output = model(input)
  88. loss = criteria(output, target)
  89. optim.zero_grad()
  90. loss.backward()
  91. optim.step()
  92. def eval_step(model, criteria, input, target):
  93. output = model(input)
  94. loss = criteria(output, target).item()
  95. return float(loss), (output.argmax(1) == target).type(torch.float).sum().item()
  96. def evaluate(model, criteria, dev_dataloader):
  97. model.eval()
  98. total_loss = 0
  99. correct_predictions = 0
  100. total_predictions = 0
  101. counter = 0
  102. with torch.no_grad():
  103. for batch in dev_dataloader:
  104. input = F.to_tensor(batch["token_ids"], padding_value=1).to(DEVICE)
  105. target = torch.tensor(batch["target"]).to(DEVICE)
  106. loss, predictions = eval_step(model, criteria, input, target)
  107. total_loss += loss
  108. correct_predictions += predictions
  109. total_predictions += len(target)
  110. counter += 1
  111. return total_loss / counter, correct_predictions / total_predictions
# Device that the model and batch tensors are moved to: CUDA when available,
# otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  113. # Train Uploaders' models
  114. def train(X, y, out_classes, epochs=35, batch_size=128):
  115. # print(X.shape, y.shape)
  116. from torchdata.datapipes.iter import IterableWrapper
  117. X = sentence_preprocess(X)
  118. data_size = len(X)
  119. train_datapipe = list(zip(X, y))
  120. train_datapipe = IterableWrapper(train_datapipe)
  121. train_datapipe = train_datapipe.batch(batch_size)
  122. train_datapipe = train_datapipe.rows2columnar(["token_ids", "target"])
  123. train_dataloader = DataLoader(train_datapipe, batch_size=None)
  124. num_classes = 2
  125. input_dim = 768
  126. classifier_head = RobertaClassificationHead(num_classes=num_classes, input_dim=input_dim)
  127. model = XLMR_BASE_ENCODER.get_model(head=classifier_head)
  128. learning_rate = 1e-5
  129. optim = AdamW(model.parameters(), lr=learning_rate)
  130. criteria = nn.CrossEntropyLoss()
  131. model.to(DEVICE)
  132. num_epochs = 10
  133. for e in range(num_epochs):
  134. for batch in train_dataloader:
  135. input = F.to_tensor(batch["token_ids"], padding_value=1).to(DEVICE)
  136. target = torch.tensor(batch["target"]).to(DEVICE)
  137. train_step(model, criteria, optim, input, target)
  138. loss, accuracy = evaluate(model, criteria, train_dataloader)
  139. print("Epoch = [{}], loss = [{}], accuracy = [{}]".format(e, loss, accuracy))
  140. return model
  141. def eval_prediction(pred_y, target_y):
  142. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  143. if not isinstance(pred_y, np.ndarray):
  144. pred_y = pred_y.detach().cpu().numpy()
  145. if len(pred_y.shape) == 1:
  146. predicted = np.array(pred_y)
  147. else:
  148. predicted = np.argmax(pred_y, 1)
  149. annos = np.array(target_y)
  150. # print(predicted, annos)
  151. # annos = target_y
  152. total = predicted.shape[0]
  153. correct = (predicted == annos).sum().item()
  154. criterion = nn.CrossEntropyLoss()
  155. return correct / total