You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

get_data.py 414 B

123456789101112131415
  1. import torch
  2. from torchtext.datasets import SST2
  3. def get_sst2(data_root="./data"):
  4. train_datapipe = SST2(root="./data", split="train")
  5. X_train = [x[0] for x in train_datapipe]
  6. y_train = [x[1] for x in train_datapipe]
  7. dev_datapipe = SST2(root="./data", split="dev")
  8. X_test = [x[0] for x in dev_datapipe]
  9. y_test = [x[1] for x in dev_datapipe]
  10. return X_train, y_train, X_test, y_test