You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

get_data.py 401 B

1234567891011121314
  1. from torchtext.datasets import SST2
  2. def get_sst2(data_root="./data"):
  3. train_datapipe = SST2(root="./data", split="train")
  4. X_train = [x[0] for x in train_datapipe]
  5. y_train = [x[1] for x in train_datapipe]
  6. dev_datapipe = SST2(root="./data", split="dev")
  7. X_test = [x[0] for x in dev_datapipe]
  8. y_test = [x[1] for x in dev_datapipe]
  9. return X_train, y_train, X_test, y_test