import torch from torchtext.datasets import SST2 def get_sst2(data_root="./data"): train_datapipe = SST2(root="./data", split="train") X_train = [x[0] for x in train_datapipe] y_train = [x[1] for x in train_datapipe] dev_datapipe = SST2(root="./data", split="dev") X_test = [x[0] for x in dev_datapipe] y_test = [x[1] for x in dev_datapipe] return X_train, y_train, X_test, y_test